Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions eval/compiler/flat_expr_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2395,9 +2395,7 @@ struct ConstantFoldingTestCase {

class UnknownFunctionImpl : public cel::Function {
absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull,
google::protobuf::MessageFactory* absl_nonnull,
google::protobuf::Arena* absl_nonnull) const override {
const InvokeContext& context) const override {
return cel::UnknownValue();
}
};
Expand Down
1 change: 0 additions & 1 deletion eval/public/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ cc_library(
"//eval/internal:interop",
"//internal:status_macros",
"//runtime:function",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down
21 changes: 4 additions & 17 deletions eval/public/cel_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
#include <cstddef>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "common/value.h"
#include "eval/internal/interop.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "runtime/function.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

Expand Down Expand Up @@ -56,9 +52,7 @@ bool CelFunction::MatchArguments(absl::Span<const cel::Value> arguments) const {

absl::StatusOr<Value> CelFunction::Invoke(
absl::Span<const cel::Value> arguments,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const {
const cel::Function::InvokeContext& context) const {
std::vector<CelValue> legacy_args;
legacy_args.reserve(arguments.size());

Expand All @@ -68,22 +62,15 @@ absl::StatusOr<Value> CelFunction::Invoke(
// interpreter expects to only be used with internal program steps.
for (const auto& arg : arguments) {
CEL_ASSIGN_OR_RETURN(legacy_args.emplace_back(),
ToLegacyValue(arena, arg, true));
ToLegacyValue(context.arena(), arg, true));
}

CelValue legacy_result;

CEL_RETURN_IF_ERROR(Evaluate(legacy_args, &legacy_result, arena));
CEL_RETURN_IF_ERROR(Evaluate(legacy_args, &legacy_result, context.arena()));

return cel::interop_internal::LegacyValueToModernValueOrDie(
arena, legacy_result, /*unchecked=*/true);
}

absl::StatusOr<Value> CelFunction::Invoke(
absl::Span<const cel::Value> arguments,
const cel::Function::InvokeContext& context) const {
return CelFunction::Invoke(arguments, context.descriptor_pool(),
context.message_factory(), context.arena());
context.arena(), legacy_result, /*unchecked=*/true);
}

} // namespace google::api::expr::runtime
10 changes: 2 additions & 8 deletions eval/public/cel_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include <utility>

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand All @@ -12,8 +11,6 @@
#include "eval/public/cel_value.h"
#include "runtime/function.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

Expand Down Expand Up @@ -65,11 +62,8 @@ class CelFunction : public ::cel::Function {
bool MatchArguments(absl::Span<const cel::Value> arguments) const;

// Implements cel::Function.
absl::StatusOr<cel::Value> Invoke(
absl::Span<const cel::Value> arguments,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const final;
using cel::Function::Invoke;

absl::StatusOr<cel::Value> Invoke(
absl::Span<const cel::Value> arguments,
const cel::Function::InvokeContext& context) const final;
Expand Down
8 changes: 3 additions & 5 deletions runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,11 @@ cc_test(
":function_registry",
"//common:function_descriptor",
"//common:kind",
"//common:value",
"//internal:testing",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_protobuf//:protobuf",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)

Expand Down Expand Up @@ -505,13 +506,11 @@ cc_library(
":function",
":register_function_helper",
"//common:function_descriptor",
"//common:kind",
"//common:value",
"//internal:status_macros",
"//runtime/internal:function_adapter",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -588,7 +587,6 @@ cc_test(
"//parser",
"//parser:options",
"//runtime/internal:runtime_impl",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
Expand Down
4 changes: 1 addition & 3 deletions runtime/activation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ class FunctionImpl : public cel::Function {
FunctionImpl() = default;

absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull,
google::protobuf::MessageFactory* absl_nonnull,
google::protobuf::Arena* absl_nonnull) const override {
const InvokeContext& context) const override {
return NullValue();
}
};
Expand Down
26 changes: 17 additions & 9 deletions runtime/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ class Function {
google::protobuf::Arena* absl_nonnull arena_;
};

ABSL_DEPRECATED("Use the InvokeContext overload instead.")
inline absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const;

// Attempt to evaluate an extension function based on the runtime arguments
// during the evaluation of a CEL expression.
//
Expand All @@ -72,18 +79,19 @@ class Function {
//
// A cel::ErrorValue typed result is considered a recoverable error and
// follows CEL's logical short-circuiting behavior.
virtual absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const = 0;
virtual absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const InvokeContext& context) const {
return Invoke(args, context.descriptor_pool(), context.message_factory(),
context.arena());
}
const InvokeContext& context) const = 0;
};

absl::StatusOr<Value> Function::Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const {
InvokeContext context(descriptor_pool, message_factory, arena);
return Invoke(args, context);
}

} // namespace cel

#endif // THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_
40 changes: 12 additions & 28 deletions runtime/function_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,6 @@ struct ToArgsHelper {
}
};

class FunctionAdapterBase : public Function {
public:
using Function::Invoke;

// Should not be called by CEL, but added for backward compatibility for
// client code tests.
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const final {
Function::InvokeContext context(descriptor_pool, message_factory, arena);
return Invoke(args, context);
}
};

} // namespace runtime_internal

// Adapter class for generating CEL extension functions from a one argument
Expand Down Expand Up @@ -247,12 +231,12 @@ class NullaryFunctionAdapter
}

private:
class UnaryFunctionImpl : public runtime_internal::FunctionAdapterBase {
class UnaryFunctionImpl : public Function {
public:
explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const Function::InvokeContext& context) const override {
const Function::InvokeContext& context) const final {
if (args.size() != 0) {
return absl::InvalidArgumentError(
"unexpected number of arguments for nullary function");
Expand Down Expand Up @@ -346,12 +330,12 @@ class UnaryFunctionAdapter : public RegisterHelper<UnaryFunctionAdapter<T, U>> {
}

private:
class UnaryFunctionImpl : public runtime_internal::FunctionAdapterBase {
class UnaryFunctionImpl : public Function {
public:
explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const Function::InvokeContext& context) const override {
const Function::InvokeContext& context) const final {
using ArgTraits = runtime_internal::AdaptedTypeTraits<U>;
if (args.size() != 1) {
return absl::InvalidArgumentError(
Expand Down Expand Up @@ -497,12 +481,12 @@ class BinaryFunctionAdapter
}

private:
class BinaryFunctionImpl : public runtime_internal::FunctionAdapterBase {
class BinaryFunctionImpl : public Function {
public:
explicit BinaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const Function::InvokeContext& context) const override {
const Function::InvokeContext& context) const final {
using Arg1Traits = runtime_internal::AdaptedTypeTraits<U>;
using Arg2Traits = runtime_internal::AdaptedTypeTraits<V>;
if (args.size() != 2) {
Expand Down Expand Up @@ -588,12 +572,12 @@ class TernaryFunctionAdapter
}

private:
class TernaryFunctionImpl : public runtime_internal::FunctionAdapterBase {
class TernaryFunctionImpl : public Function {
public:
explicit TernaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const Function::InvokeContext& context) const override {
const Function::InvokeContext& context) const final {
using Arg1Traits = runtime_internal::AdaptedTypeTraits<U>;
using Arg2Traits = runtime_internal::AdaptedTypeTraits<V>;
using Arg3Traits = runtime_internal::AdaptedTypeTraits<W>;
Expand Down Expand Up @@ -684,12 +668,12 @@ class QuaternaryFunctionAdapter
}

private:
class QuaternaryFunctionImpl : public runtime_internal::FunctionAdapterBase {
class QuaternaryFunctionImpl : public Function {
public:
explicit QuaternaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const Function::InvokeContext& context) const override {
const Function::InvokeContext& context) const final {
using Arg1Traits = runtime_internal::AdaptedTypeTraits<U>;
using Arg2Traits = runtime_internal::AdaptedTypeTraits<V>;
using Arg3Traits = runtime_internal::AdaptedTypeTraits<W>;
Expand Down Expand Up @@ -807,7 +791,7 @@ class NaryFunctionAdapter
}

private:
class NaryFunctionImpl : public runtime_internal::FunctionAdapterBase {
class NaryFunctionImpl : public Function {
private:
using ArgBuffer = std::tuple<
typename runtime_internal::AdaptedTypeTraits<Args>::AssignableType...>;
Expand All @@ -816,7 +800,7 @@ class NaryFunctionAdapter
explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const Function::InvokeContext& context) const override {
const Function::InvokeContext& context) const final {
if (args.size() != sizeof...(Args)) {
return absl::InvalidArgumentError(
absl::StrCat("unexpected number of arguments for ", sizeof...(Args),
Expand Down
14 changes: 5 additions & 9 deletions runtime/function_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@
#include <tuple>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "internal/testing.h"
#include "runtime/activation.h"
#include "runtime/function.h"
#include "runtime/function_adapter.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_provider.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {

Expand All @@ -50,11 +49,8 @@ class ConstIntFunction : public cel::Function {
return {"ConstFunction", false, {}};
}

absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const InvokeContext& context) const override {
return IntValue(42);
}
};
Expand Down
12 changes: 3 additions & 9 deletions runtime/optional_types_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
Expand All @@ -45,8 +44,6 @@
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {
Expand Down Expand Up @@ -306,13 +303,10 @@ class UnreachableFunction final : public cel::Function {
public:
explicit UnreachableFunction(int64_t* count) : count_(count) {}

absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const InvokeContext& context) const override {
++(*count_);
return ErrorValue{absl::CancelledError()};
return ErrorValue(absl::CancelledError());
}

private:
Expand Down