diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 141cabdf1..b51f6b706 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -515,6 +515,7 @@ class FlatExprVisitor : public cel::AstVisitor { resolved_select_expr_(nullptr), options_(options), program_optimizers_(std::move(program_optimizers)), + reference_map_(reference_map), issue_collector_(issue_collector), program_builder_(program_builder), extension_context_(extension_context), @@ -1670,6 +1671,15 @@ class FlatExprVisitor : public cel::AstVisitor { suppressed_branches_.insert(expr); } + const cel::Reference& FindReference(const cel::Expr* expr) const { + auto it = reference_map_.find(expr->id()); + if (it == reference_map_.end()) { + static const cel::Reference no_reference; + return no_reference; + } + return it->second; + } + void AddResolvedFunctionStep(const cel::CallExpr* call_expr, const cel::Expr* expr, absl::string_view function) { @@ -1687,12 +1697,14 @@ class FlatExprVisitor : public cel::AstVisitor { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), - std::move(lazy_overloads)), + std::move(lazy_overloads), + FindReference(expr).overload_id()), *depth + 1); return; } AddStep(CreateFunctionStep(*call_expr, expr->id(), - std::move(lazy_overloads))); + std::move(lazy_overloads), + FindReference(expr).overload_id())); return; } @@ -1721,11 +1733,14 @@ class FlatExprVisitor : public cel::AstVisitor { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep( CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), - std::move(overloads)), + std::move(overloads), + FindReference(expr).overload_id()), *recursion_depth + 1); return; } - AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads))); + AddStep(CreateFunctionStep(*call_expr, expr->id(), + std::move(overloads), + FindReference(expr).overload_id())); } // Add a step to the program, taking ownership. If successful, returns the @@ -1963,6 +1978,7 @@ class FlatExprVisitor : public cel::AstVisitor { absl::flat_hash_set suppressed_branches_; const cel::Expr* resume_from_suppressed_branch_ = nullptr; std::vector> program_optimizers_; + const absl::flat_hash_map& reference_map_; IssueCollector& issue_collector_; ProgramBuilder& program_builder_; diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 262b66e0f..461346965 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -2397,7 +2397,8 @@ class UnknownFunctionImpl : public cel::Function { absl::StatusOr Invoke(absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull) const override { + google::protobuf::Arena* absl_nonnull, + absl::Span overload_id) const override { return cel::UnknownValue(); } }; diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index a860a4bb4..a245989bb 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -150,10 +150,11 @@ class AbstractFunctionStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. AbstractFunctionStep(const std::string& name, size_t num_arguments, - int64_t expr_id) + int64_t expr_id, std::vector&& overload_id) : ExpressionStepBase(expr_id), name_(name), - num_arguments_(num_arguments) {} + num_arguments_(num_arguments), + overload_id_(std::move(overload_id)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; @@ -172,15 +173,18 @@ class AbstractFunctionStep : public ExpressionStepBase { protected: std::string name_; size_t num_arguments_; + std::vector overload_id_; }; inline absl::StatusOr Invoke( const cel::FunctionOverloadReference& overload, int64_t expr_id, - absl::Span args, ExecutionFrameBase& frame) { + absl::Span args, ExecutionFrameBase& frame, + absl::Span overload_id) { CEL_ASSIGN_OR_RETURN( Value result, overload.implementation.Invoke(args, frame.descriptor_pool(), - frame.message_factory(), frame.arena())); + frame.message_factory(), frame.arena(), + overload_id)); if (frame.unknown_function_results_enabled() && IsUnknownFunctionResultError(result)) { @@ -240,7 +244,7 @@ absl::StatusOr AbstractFunctionStep::DoEvaluate( // Overload found and is allowed to consume the arguments. if (matched_function.has_value() && ShouldAcceptOverload(matched_function->descriptor, input_args)) { - return Invoke(*matched_function, id(), input_args, *frame); + return Invoke(*matched_function, id(), input_args, *frame, overload_id_); } return NoOverloadResult(name_, input_args, *frame); @@ -323,8 +327,9 @@ absl::StatusOr ResolveLazy( class EagerFunctionStep : public AbstractFunctionStep { public: EagerFunctionStep(std::vector overloads, - const std::string& name, size_t num_args, int64_t expr_id) - : AbstractFunctionStep(name, num_args, expr_id), + const std::string& name, size_t num_args, int64_t expr_id, + std::vector&& overload_id) + : AbstractFunctionStep(name, num_args, expr_id, std::move(overload_id)), overloads_(std::move(overloads)) {} absl::StatusOr ResolveFunction( @@ -344,8 +349,9 @@ class LazyFunctionStep : public AbstractFunctionStep { LazyFunctionStep(const std::string& name, size_t num_args, bool receiver_style, std::vector providers, - int64_t expr_id) - : AbstractFunctionStep(name, num_args, expr_id), + int64_t expr_id, + std::vector&& overload_id) + : AbstractFunctionStep(name, num_args, expr_id, std::move(overload_id)), receiver_style_(receiver_style), providers_(std::move(providers)) {} @@ -404,10 +410,12 @@ class DirectFunctionStepImpl : public DirectExpressionStep { DirectFunctionStepImpl( int64_t expr_id, const std::string& name, std::vector> arg_steps, - Resolver&& resolver) + Resolver&& resolver, + std::vector&& overload_id) : DirectExpressionStep(expr_id), name_(name), arg_steps_(std::move(arg_steps)), + overload_id_(std::move(overload_id)), resolver_(std::forward(resolver)) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, @@ -439,7 +447,8 @@ class DirectFunctionStepImpl : public DirectExpressionStep { if (resolved_function.has_value() && ShouldAcceptOverload(resolved_function->descriptor, args)) { CEL_ASSIGN_OR_RETURN(result, - Invoke(*resolved_function, expr_id_, args, frame)); + Invoke(*resolved_function, expr_id_, args, frame, + overload_id_)); return absl::OkStatus(); } @@ -468,6 +477,7 @@ class DirectFunctionStepImpl : public DirectExpressionStep { friend Resolver; std::string name_; std::vector> arg_steps_; + std::vector overload_id_; Resolver resolver_; }; @@ -476,39 +486,47 @@ class DirectFunctionStepImpl : public DirectExpressionStep { std::unique_ptr CreateDirectFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector overloads) { + std::vector overloads, + std::vector overload_id) { return std::make_unique>( expr_id, call.function(), std::move(deps), - StaticResolver(std::move(overloads))); + StaticResolver(std::move(overloads)), + std::move(overload_id)); } std::unique_ptr CreateDirectLazyFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector providers) { + std::vector providers, + std::vector overload_id) { return std::make_unique>( expr_id, call.function(), std::move(deps), - LazyResolver(std::move(providers), call.function(), call.has_target())); + LazyResolver(std::move(providers), call.function(), call.has_target()), + std::move(overload_id)); } absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call_expr, int64_t expr_id, - std::vector lazy_overloads) { + std::vector lazy_overloads, + std::vector overload_id) { bool receiver_style = call_expr.has_target(); size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0); const std::string& name = call_expr.function(); return std::make_unique(name, num_args, receiver_style, - std::move(lazy_overloads), expr_id); + std::move(lazy_overloads), expr_id, + std::move(overload_id)); } absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call_expr, int64_t expr_id, - std::vector overloads) { + std::vector overloads, + std::vector overload_id) { bool receiver_style = call_expr.has_target(); size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0); const std::string& name = call_expr.function(); return std::make_unique(std::move(overloads), name, - num_args, expr_id); + num_args, expr_id, + std::move(overload_id)); } } // namespace google::api::expr::runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 9f664dc09..ae8a12ec9 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -20,7 +20,8 @@ namespace google::api::expr::runtime { std::unique_ptr CreateDirectFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector overloads); + std::vector overloads, + std::vector overload_id = {}); // Factory method for Call-based execution step where the function has been // statically resolved from a set of lazy functions configured in the @@ -28,20 +29,23 @@ std::unique_ptr CreateDirectFunctionStep( std::unique_ptr CreateDirectLazyFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector providers); + std::vector providers, + std::vector overload_id = {}); // Factory method for Call-based execution step where the function will be // resolved at runtime (lazily) from an input Activation. absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call, int64_t expr_id, - std::vector lazy_overloads); + std::vector lazy_overloads, + std::vector overload_id = {}); // Factory method for Call-based execution step where the function has been // statically resolved from a set of eagerly functions configured in the // CelFunctionRegistry. absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call, int64_t expr_id, - std::vector overloads); + std::vector overloads, + std::vector overload_id = {}); } // namespace google::api::expr::runtime diff --git a/eval/public/cel_function.cc b/eval/public/cel_function.cc index bee39ec8e..7fbad94ee 100644 --- a/eval/public/cel_function.cc +++ b/eval/public/cel_function.cc @@ -57,7 +57,8 @@ absl::StatusOr CelFunction::Invoke( absl::Span arguments, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const { std::vector legacy_args; legacy_args.reserve(arguments.size()); diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index c978c6f67..03a9cf691 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -69,7 +69,8 @@ class CelFunction : public ::cel::Function { absl::Span arguments, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override; + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override; // CelFunction descriptor const CelFunctionDescriptor& descriptor() const { return descriptor_; } diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index 30851341a..44264b083 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -69,7 +69,8 @@ class FunctionImpl : public cel::Function { absl::StatusOr Invoke(absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull) const override { + google::protobuf::Arena* absl_nonnull, + absl::Span overload_id) const override { return NullValue(); } }; diff --git a/runtime/function.h b/runtime/function.h index c2a3d257a..c9bb55a7d 100644 --- a/runtime/function.h +++ b/runtime/function.h @@ -47,7 +47,8 @@ class Function { absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const = 0; + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id = {}) const = 0; }; } // namespace cel diff --git a/runtime/function_adapter.h b/runtime/function_adapter.h index 1c96a6ea1..6e3907d7d 100644 --- a/runtime/function_adapter.h +++ b/runtime/function_adapter.h @@ -231,7 +231,8 @@ class NullaryFunctionAdapter absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { if (args.size() != 0) { return absl::InvalidArgumentError( "unexpected number of arguments for nullary function"); @@ -316,7 +317,8 @@ class UnaryFunctionAdapter : public RegisterHelper> { absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { using ArgTraits = runtime_internal::AdaptedTypeTraits; if (args.size() != 1) { return absl::InvalidArgumentError( @@ -456,7 +458,8 @@ class BinaryFunctionAdapter absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { using Arg1Traits = runtime_internal::AdaptedTypeTraits; using Arg2Traits = runtime_internal::AdaptedTypeTraits; if (args.size() != 2) { @@ -537,7 +540,8 @@ class TernaryFunctionAdapter absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { using Arg1Traits = runtime_internal::AdaptedTypeTraits; using Arg2Traits = runtime_internal::AdaptedTypeTraits; using Arg3Traits = runtime_internal::AdaptedTypeTraits; @@ -624,7 +628,8 @@ class QuaternaryFunctionAdapter absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { using Arg1Traits = runtime_internal::AdaptedTypeTraits; using Arg2Traits = runtime_internal::AdaptedTypeTraits; using Arg3Traits = runtime_internal::AdaptedTypeTraits; diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index 99b5ec406..064296143 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -54,7 +54,8 @@ class ConstIntFunction : public cel::Function { absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { return IntValue(42); } }; diff --git a/runtime/optional_types_test.cc b/runtime/optional_types_test.cc index 1f118b639..3573a8648 100644 --- a/runtime/optional_types_test.cc +++ b/runtime/optional_types_test.cc @@ -310,7 +310,8 @@ class UnreachableFunction final : public cel::Function { absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) const override { + google::protobuf::Arena* absl_nonnull arena, + absl::Span overload_id) const override { ++(*count_); return ErrorValue{absl::CancelledError()}; }