diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 262b66e0f..0dba74a29 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -2395,9 +2395,7 @@ struct ConstantFoldingTestCase { class UnknownFunctionImpl : public cel::Function { absl::StatusOr Invoke(absl::Span args, - const google::protobuf::DescriptorPool* absl_nonnull, - google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull) const override { + const InvokeContext& context) const override { return cel::UnknownValue(); } }; diff --git a/eval/public/BUILD b/eval/public/BUILD index 4e25c0481..31ad2d480 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -198,7 +198,6 @@ cc_library( "//eval/internal:interop", "//internal:status_macros", "//runtime:function", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", diff --git a/eval/public/cel_function.cc b/eval/public/cel_function.cc index 10d0fd798..9b760d1ec 100644 --- a/eval/public/cel_function.cc +++ b/eval/public/cel_function.cc @@ -3,7 +3,6 @@ #include #include -#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "common/value.h" @@ -11,9 +10,6 @@ #include "eval/public/cel_value.h" #include "internal/status_macros.h" #include "runtime/function.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -56,9 +52,7 @@ bool CelFunction::MatchArguments(absl::Span arguments) const { absl::StatusOr CelFunction::Invoke( absl::Span arguments, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const { + const cel::Function::InvokeContext& context) const { std::vector legacy_args; legacy_args.reserve(arguments.size()); @@ -68,22 +62,15 @@ absl::StatusOr CelFunction::Invoke( // interpreter expects to only be used with internal program steps. for (const auto& arg : arguments) { CEL_ASSIGN_OR_RETURN(legacy_args.emplace_back(), - ToLegacyValue(arena, arg, true)); + ToLegacyValue(context.arena(), arg, true)); } CelValue legacy_result; - CEL_RETURN_IF_ERROR(Evaluate(legacy_args, &legacy_result, arena)); + CEL_RETURN_IF_ERROR(Evaluate(legacy_args, &legacy_result, context.arena())); return cel::interop_internal::LegacyValueToModernValueOrDie( - arena, legacy_result, /*unchecked=*/true); -} - -absl::StatusOr CelFunction::Invoke( - absl::Span arguments, - const cel::Function::InvokeContext& context) const { - return CelFunction::Invoke(arguments, context.descriptor_pool(), - context.message_factory(), context.arena()); + context.arena(), legacy_result, /*unchecked=*/true); } } // namespace google::api::expr::runtime diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index d2a8fd2cf..6c9ff2e7a 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -3,7 +3,6 @@ #include -#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -12,8 +11,6 @@ #include "eval/public/cel_value.h" #include "runtime/function.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -65,11 +62,8 @@ class CelFunction : public ::cel::Function { bool MatchArguments(absl::Span arguments) const; // Implements cel::Function. - absl::StatusOr Invoke( - absl::Span arguments, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const final; + using cel::Function::Invoke; + absl::StatusOr Invoke( absl::Span arguments, const cel::Function::InvokeContext& context) const final; diff --git a/runtime/BUILD b/runtime/BUILD index 53c6174bf..87d8bb25c 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -156,10 +156,11 @@ cc_test( ":function_registry", "//common:function_descriptor", "//common:kind", + "//common:value", "//internal:testing", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", - "@com_google_protobuf//:protobuf", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) @@ -505,13 +506,11 @@ cc_library( ":function", ":register_function_helper", "//common:function_descriptor", - "//common:kind", "//common:value", "//internal:status_macros", "//runtime/internal:function_adapter", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -588,7 +587,6 @@ cc_test( "//parser", "//parser:options", "//runtime/internal:runtime_impl", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index 30851341a..e6a74f027 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -67,9 +67,7 @@ class FunctionImpl : public cel::Function { FunctionImpl() = default; absl::StatusOr Invoke(absl::Span args, - const google::protobuf::DescriptorPool* absl_nonnull, - google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull) const override { + const InvokeContext& context) const override { return NullValue(); } }; diff --git a/runtime/function.h b/runtime/function.h index b89d421d7..6ab1e4a7d 100644 --- a/runtime/function.h +++ b/runtime/function.h @@ -64,6 +64,13 @@ class Function { google::protobuf::Arena* absl_nonnull arena_; }; + ABSL_DEPRECATED("Use the InvokeContext overload instead.") + inline absl::StatusOr Invoke( + absl::Span args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + // Attempt to evaluate an extension function based on the runtime arguments // during the evaluation of a CEL expression. // @@ -72,18 +79,19 @@ class Function { // // A cel::ErrorValue typed result is considered a recoverable error and // follows CEL's logical short-circuiting behavior. - virtual absl::StatusOr Invoke( - absl::Span args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const = 0; virtual absl::StatusOr Invoke(absl::Span args, - const InvokeContext& context) const { - return Invoke(args, context.descriptor_pool(), context.message_factory(), - context.arena()); - } + const InvokeContext& context) const = 0; }; +absl::StatusOr Function::Invoke( + absl::Span args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + InvokeContext context(descriptor_pool, message_factory, arena); + return Invoke(args, context); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_ diff --git a/runtime/function_adapter.h b/runtime/function_adapter.h index f899c497e..62932a027 100644 --- a/runtime/function_adapter.h +++ b/runtime/function_adapter.h @@ -159,22 +159,6 @@ struct ToArgsHelper { } }; -class FunctionAdapterBase : public Function { - public: - using Function::Invoke; - - // Should not be called by CEL, but added for backward compatibility for - // client code tests. - absl::StatusOr Invoke( - absl::Span args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const final { - Function::InvokeContext context(descriptor_pool, message_factory, arena); - return Invoke(args, context); - } -}; - } // namespace runtime_internal // Adapter class for generating CEL extension functions from a one argument @@ -247,12 +231,12 @@ class NullaryFunctionAdapter } private: - class UnaryFunctionImpl : public runtime_internal::FunctionAdapterBase { + class UnaryFunctionImpl : public Function { public: explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} absl::StatusOr Invoke( absl::Span args, - const Function::InvokeContext& context) const override { + const Function::InvokeContext& context) const final { if (args.size() != 0) { return absl::InvalidArgumentError( "unexpected number of arguments for nullary function"); @@ -346,12 +330,12 @@ class UnaryFunctionAdapter : public RegisterHelper> { } private: - class UnaryFunctionImpl : public runtime_internal::FunctionAdapterBase { + class UnaryFunctionImpl : public Function { public: explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} absl::StatusOr Invoke( absl::Span args, - const Function::InvokeContext& context) const override { + const Function::InvokeContext& context) const final { using ArgTraits = runtime_internal::AdaptedTypeTraits; if (args.size() != 1) { return absl::InvalidArgumentError( @@ -497,12 +481,12 @@ class BinaryFunctionAdapter } private: - class BinaryFunctionImpl : public runtime_internal::FunctionAdapterBase { + class BinaryFunctionImpl : public Function { public: explicit BinaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} absl::StatusOr Invoke( absl::Span args, - const Function::InvokeContext& context) const override { + const Function::InvokeContext& context) const final { using Arg1Traits = runtime_internal::AdaptedTypeTraits; using Arg2Traits = runtime_internal::AdaptedTypeTraits; if (args.size() != 2) { @@ -588,12 +572,12 @@ class TernaryFunctionAdapter } private: - class TernaryFunctionImpl : public runtime_internal::FunctionAdapterBase { + class TernaryFunctionImpl : public Function { public: explicit TernaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} absl::StatusOr Invoke( absl::Span args, - const Function::InvokeContext& context) const override { + const Function::InvokeContext& context) const final { using Arg1Traits = runtime_internal::AdaptedTypeTraits; using Arg2Traits = runtime_internal::AdaptedTypeTraits; using Arg3Traits = runtime_internal::AdaptedTypeTraits; @@ -684,12 +668,12 @@ class QuaternaryFunctionAdapter } private: - class QuaternaryFunctionImpl : public runtime_internal::FunctionAdapterBase { + class QuaternaryFunctionImpl : public Function { public: explicit QuaternaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} absl::StatusOr Invoke( absl::Span args, - const Function::InvokeContext& context) const override { + const Function::InvokeContext& context) const final { using Arg1Traits = runtime_internal::AdaptedTypeTraits; using Arg2Traits = runtime_internal::AdaptedTypeTraits; using Arg3Traits = runtime_internal::AdaptedTypeTraits; @@ -807,7 +791,7 @@ class NaryFunctionAdapter } private: - class NaryFunctionImpl : public runtime_internal::FunctionAdapterBase { + class NaryFunctionImpl : public Function { private: using ArgBuffer = std::tuple< typename runtime_internal::AdaptedTypeTraits::AssignableType...>; @@ -816,7 +800,7 @@ class NaryFunctionAdapter explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} absl::StatusOr Invoke( absl::Span args, - const Function::InvokeContext& context) const override { + const Function::InvokeContext& context) const final { if (args.size() != sizeof...(Args)) { return absl::InvalidArgumentError( absl::StrCat("unexpected number of arguments for ", sizeof...(Args), diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index 99b5ec406..af7f5bc06 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -19,19 +19,18 @@ #include #include -#include "absl/base/nullability.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/kind.h" +#include "common/value.h" #include "internal/testing.h" #include "runtime/activation.h" #include "runtime/function.h" #include "runtime/function_adapter.h" #include "runtime/function_overload_reference.h" #include "runtime/function_provider.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace cel { @@ -50,11 +49,8 @@ class ConstIntFunction : public cel::Function { return {"ConstFunction", false, {}}; } - absl::StatusOr Invoke( - absl::Span args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { return IntValue(42); } }; diff --git a/runtime/optional_types_test.cc b/runtime/optional_types_test.cc index 1f118b639..07029732f 100644 --- a/runtime/optional_types_test.cc +++ b/runtime/optional_types_test.cc @@ -22,7 +22,6 @@ #include #include "cel/expr/syntax.pb.h" -#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" @@ -45,8 +44,6 @@ #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace cel::extensions { namespace { @@ -306,13 +303,10 @@ class UnreachableFunction final : public cel::Function { public: explicit UnreachableFunction(int64_t* count) : count_(count) {} - absl::StatusOr Invoke( - absl::Span args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { ++(*count_); - return ErrorValue{absl::CancelledError()}; + return ErrorValue(absl::CancelledError()); } private: