From 01f0a03d25cc4a5a62ec6fb825f1269e2dfe2ddd Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Fri, 19 Feb 2021 20:02:31 +0200 Subject: [PATCH 01/25] add closure support and a few tests --- stan/math/prim/err/elementwise_check.hpp | 7 + stan/math/prim/fun/value_of.hpp | 13 ++ stan/math/prim/functor.hpp | 1 + stan/math/prim/functor/closure_adapter.hpp | 134 ++++++++++++++++++ stan/math/prim/functor/integrate_ode_rk45.hpp | 2 +- ...grate_ode_std_vector_interface_adapter.hpp | 52 ++++++- stan/math/prim/functor/ode_rk45.hpp | 23 ++- stan/math/prim/meta.hpp | 1 + stan/math/prim/meta/is_stan_closure.hpp | 76 ++++++++++ stan/math/prim/meta/return_type.hpp | 4 +- stan/math/rev/core/accumulate_adjoints.hpp | 25 ++++ stan/math/rev/core/count_vars.hpp | 24 ++++ stan/math/rev/core/deep_copy_vars.hpp | 13 ++ stan/math/rev/core/save_varis.hpp | 23 +++ stan/math/rev/core/zero_adjoints.hpp | 21 +++ stan/math/rev/functor/integrate_ode_adams.hpp | 4 +- stan/math/rev/functor/integrate_ode_bdf.hpp | 4 +- stan/math/rev/functor/ode_adams.hpp | 19 ++- stan/math/rev/functor/ode_bdf.hpp | 19 ++- .../functor/integrate_ode_rk45_rev_test.cpp | 64 +++++++++ .../math/rev/functor/ode_rk45_rev_test.cpp | 80 +++++++++++ .../rev/functor/reduce_sum_closure_test.cpp | 83 +++++++++++ 22 files changed, 676 insertions(+), 16 deletions(-) create mode 100644 stan/math/prim/functor/closure_adapter.hpp create mode 100644 stan/math/prim/meta/is_stan_closure.hpp create mode 100644 test/unit/math/rev/functor/reduce_sum_closure_test.cpp diff --git a/stan/math/prim/err/elementwise_check.hpp b/stan/math/prim/err/elementwise_check.hpp index de2e192d41d..4c4386c1f06 100644 --- a/stan/math/prim/err/elementwise_check.hpp +++ b/stan/math/prim/err/elementwise_check.hpp @@ -124,6 +124,13 @@ inline void elementwise_check(const F& is_good, const char* function, }(); } } +template * = nullptr> +inline void elementwise_check(const F& is_good, const char* function, + const char* name, const T& x, const char* must_be, + const Indexings&... indexings) { + // XXX skip closures +} /** * Check that the predicate holds for all elements of the value of `x`. This * overload works on Eigen types that support linear indexing. diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index b2593a856a6..73d4ac69e9a 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -69,6 +69,19 @@ inline auto value_of(EigMat&& M) { std::forward(M)); } +/** + * Closures that capture non-arithmetic types have value_of__() method. + * + * @tparam F Input element type + * @param[in] f Input closure + * @return closure + **/ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto value_of(const F& f) { + return f.value_of__(); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor.hpp b/stan/math/prim/functor.hpp index 5568fe9cce9..5def74f0cad 100644 --- a/stan/math/prim/functor.hpp +++ b/stan/math/prim/functor.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp new file mode 100644 index 00000000000..378b5390395 --- /dev/null +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -0,0 +1,134 @@ +#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP +#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP + +#include +#include + +namespace stan { +namespace math { + +template +struct empty_closure { + using captured_scalar_t__ = double; + using ValueOf__ = empty_closure; + using CopyOf__ = empty_closure; + static const size_t vars_count__ = 0; + F f_; + + explicit empty_closure(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_(args..., msgs); + } + size_t count_vars__() const { return 0; } + auto value_of__() const { return empty_closure(f_); } + auto shallow_copy__() const { return empty_closure(f_); } + auto deep_copy_vars__() const { return empty_closure(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +template +struct one_arg_closure { + using captured_scalar_t__ = return_type_t; + using ValueOf__ = one_arg_closure()))>; + using CopyOf__ = one_arg_closure; + F f_; + T s_; + + explicit one_arg_closure(const F& f, const T& s) : f_(f), s_(s) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_(s_, args..., msgs); + } + size_t count_vars__() const { return count_vars(s_); } + auto value_of__() const { return ValueOf__(f_, value_of(s_)); } + auto shallow_copy__() const { return one_arg_closure(f_, s_); } + auto deep_copy_vars__() const { + return one_arg_closure(f_, deep_copy_vars(s_)); + } + void zero_adjoints__() { zero_adjoints(s_); } + double* accumulate_adjoints__(double* dest) const { + return accumulate_adjoints(dest, s_); + } + template + Vari** save_varis__(Vari** dest) const { + return save_varis(dest, s_); + } +}; + +template +auto from_lambda(F f) { + return empty_closure(f); +} + +template +auto from_lambda(F f, T a) { + return one_arg_closure(f, a); +} + +template +struct lpdf_wrapper { + using captured_scalar_t__ = return_type_t; + using ValueOf__ + = lpdf_wrapper().value_of__()), false>; + using CopyOf__ + = lpdf_wrapper().copy_of__()), false>; + capture_type_t f_; + + explicit lpdf_wrapper(const F& f) : f_(f) {} + + template + auto with_propto() { + return lpdf_wrapper < Propto && propto, F, true > (f_); + } + + template + auto operator()(Args... args) const { + return f_.template operator() < Propto && propto > (args...); + } + size_t count_vars__() const { return count_vars(f_); } + auto value_of__() const { return ValueOf__(value_of(f_)); } + auto deep_copy_vars__() const { return CopyOf__(deep_copy_vars(f_)); } + auto copy_of__() const { return CopyOf__(f_.copy_of__()); } + void zero_adjoints__() { zero_adjoints(f_); } + double* accumulate_adjoints__(double* dest) const { + return accumulate_adjoints(dest, f_); + } + template + Vari** save_varis__(Vari** dest) const { + return save_varis(dest, f_); + } +}; + +struct reduce_sum_closure_adapter { + template + auto operator()(const std::vector& sub_slice, std::size_t start, + std::size_t end, std::ostream* msgs, const F& f, + Args... args) const { + return f(msgs, sub_slice, start, end, args...); + } +}; + +namespace internal { + +struct ode_closure_adapter { + template + auto operator()(const T0& t, const Eigen::Matrix& y, + std::ostream* msgs, const F& f, Args... args) const { + return f(msgs, t, y, args...); + } +}; + +} // namespace internal + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 1e568d6b1ea..afc3fdbb1cb 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -26,7 +26,7 @@ inline auto integrate_ode_rk45( ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; y_converted.reserve(y.size()); for (size_t i = 0; i < y.size(); ++i) diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 9b0178e27b5..45011299553 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_STD_VECTOR_INTERFACE_ADAPTER_HPP #define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_STD_VECTOR_INTERFACE_ADAPTER_HPP +#include #include #include #include @@ -19,11 +20,14 @@ namespace internal { * state as an Eigen::Matrix. The adapter converts to and from these forms * so that the old ODE interfaces can work. */ -template -struct integrate_ode_std_vector_interface_adapter { - const F f_; +template +struct integrate_ode_std_vector_interface_adapter_impl; - explicit integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {} +template +struct integrate_ode_std_vector_interface_adapter_impl { + const F& f_; + explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) + : f_(f) {} template auto operator()(const T0& t, const Eigen::Matrix& y, @@ -34,6 +38,46 @@ struct integrate_ode_std_vector_interface_adapter { } }; +template +struct integrate_ode_std_vector_interface_adapter_impl { + using captured_scalar_t__ = typename F::captured_scalar_t__; + using ValueOf__ + = integrate_ode_std_vector_interface_adapter_impl; + F f_; + + explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) + : f_(f) {} + + template + auto operator()(std::ostream* msgs, const T0& t, + const Eigen::Matrix& y, + const std::vector& theta, const std::vector& x, + const std::vector& x_int) const { + return to_vector(f_(msgs, t, to_array_1d(y), theta, x, x_int)); + } + + size_t count_vars__() const { return f_.count_vars__(); } + auto value_of__() const { return ValueOf__(f_.value_of__()); } + auto deep_copy_vars__() const { + return integrate_ode_std_vector_interface_adapter_impl( + f_.deep_copy_vars__()); + } + void zero_adjoints__() { f_.zero_adjoints__(); } + double* accumulate_adjoints__(double* dest) const { + return f_.accumulate_adjoints__(dest); + } + template + Vari** save_varis__(Vari** dest) const { + return f_.save_varis__(dest); + } +}; + +template +using integrate_ode_std_vector_interface_adapter + = integrate_ode_std_vector_interface_adapter_impl::value, + F>; + } // namespace internal } // namespace math diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index 8a90bd1e2c7..68eeb00043d 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -53,7 +53,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... Args, require_eigen_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, @@ -158,6 +159,22 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, return y; } +template * = nullptr, + require_stan_closure_t* = nullptr> +std::vector, + Eigen::Dynamic, 1>> +ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, + T_t0 t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const Args&... args) { + internal::ode_closure_adapter f_adapter; + return ode_rk45_tol_impl(function_name, f_adapter, y0_arg, t0, ts, + relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in @@ -196,7 +213,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -242,7 +259,7 @@ ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45(const F& f, const T_y0& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/meta.hpp b/stan/math/prim/meta.hpp index 500eb70864d..89768cc8626 100644 --- a/stan/math/prim/meta.hpp +++ b/stan/math/prim/meta.hpp @@ -213,6 +213,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp new file mode 100644 index 00000000000..82860278df6 --- /dev/null +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -0,0 +1,76 @@ +#ifndef STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP +#define STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP + +#include +#include +#include + +#include + +namespace stan { + +/** + * Checks if type is a closure object. + * @tparam The type to check + * @ingroup type_trait + */ +template +struct is_stan_closure : std::false_type {}; + +template +struct is_stan_closure> + : std::true_type {}; + +template +struct scalar_type> { + using type = typename T::captured_scalar_t__; +}; + +STAN_ADD_REQUIRE_UNARY(stan_closure, is_stan_closure, general_types); + +template +struct fn_return_type { + using type = double; +}; + +template +struct fn_return_type> { + using type = typename T::captured_scalar_t__; +}; + +/** + * Convenience type for the return type of the specified template + * parameters. + * + * @tparam F callable type + * @tparam Ts sequence of types + * @see return_type + * @ingroup type_trait + */ +template +using fn_return_type_t + = return_type_t::type, Args...>; + +template +struct capture_type; + +template +struct capture_type { + using type = const T&; +}; +template +struct capture_type>> { + using type = std::remove_reference_t; +}; +template +struct capture_type>> { + using type = typename std::remove_reference_t::CopyOf__; +}; +template +using capture_type_t = typename capture_type::type; + +} // namespace stan + +#endif diff --git a/stan/math/prim/meta/return_type.hpp b/stan/math/prim/meta/return_type.hpp index 79ef5816a9a..3f9afa4f504 100644 --- a/stan/math/prim/meta/return_type.hpp +++ b/stan/math/prim/meta/return_type.hpp @@ -190,8 +190,8 @@ struct return_type { template struct return_type { - using type - = scalar_lub_t, typename return_type::type>; + using type = scalar_lub_t, + typename return_type...>::type>; }; /** diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index e5b27354ebd..1a95dba46af 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -29,6 +29,10 @@ template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args); @@ -121,6 +125,27 @@ inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { return accumulate_adjoints(dest + x.size(), std::forward(args)...); } +/** + * Accumulate adjoints from f (a closure type containing vars) + * into storage pointed to by dest, + * increment the adjoint storage pointer, + * recursively accumulate the adjoints of the rest of the arguments, + * and return final position of storage pointer. + * + * @tparam F A closure type capturing vars. + * @tparam Pargs Types of remaining arguments + * @param dest Pointer to where adjoints are to be accumulated + * @param f A closure holding vars to accumulate over + * @param args Further args to accumulate over + * @return Final position of adjoint storage pointer + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args) { + return accumulate_adjoints(f.accumulate_adjoints__(dest), + std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index b0b536a27ab..e6463de54b5 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -29,6 +29,10 @@ inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args); template inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args); + template >* = nullptr, typename... Pargs> inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args); @@ -110,6 +114,26 @@ inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { return count_vars_impl(count + 1, std::forward(args)...); } +/** + * Count the number of vars in f (a closure capturing vars), + * add it to the running total, + * count the number of vars in the remaining arguments + * and return the result. + * + * @tparam F A closure type + * @tparam Pargs Types of remaining arguments + * @param[in] count The current count of the number of vars + * @param[in] f A closure holding vars + * @param[in] args objects to be forwarded to recursive call of + * `count_vars_impl` + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args) { + return count_vars_impl(count + f.count_vars__(), + std::forward(args)...); +} + /** * Arguments without vars contribute zero to the total number of vars. * diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 06561d1a9e0..cf2e7479da5 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -81,6 +81,19 @@ inline auto deep_copy_vars(EigT&& arg) { .eval(); } +/** + * Copy the vars in f but reallocate new varis for them + * + * @tparam F A closure type + * @param f A closure of vars + * @return A new std::vector of vars + */ +template * = nullptr, + require_not_arithmetic_t>* = nullptr> +inline auto deep_copy_vars(F&& f) { + return f.deep_copy_vars__(); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index c53a5390539..6e19d54fa53 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -29,6 +29,10 @@ template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args); @@ -118,6 +122,25 @@ inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args) { return save_varis(dest + x.size(), std::forward(args)...); } +/** + * Save the vari pointers in f into the memory pointed to by dest, + * increment the dest storage pointer, + * recursively call save_varis on the rest of the arguments, + * and return the final value of the dest storage pointer. + * + * @tparam F A closure type with var value type + * @tparam Pargs Types of remaining arguments + * @param[in, out] dest Pointer to where vari pointers are saved + * @param[in] f A closure capturing vars + * @param[in] args Additional arguments to have their varis saved + * @return Final position of dest pointer + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args) { + return save_varis(f.save_varis__(dest), std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index 36368d443ee..1e10e8621ea 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -19,6 +19,10 @@ inline void zero_adjoints(var& x, Pargs&... args); template inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline void zero_adjoints(F& f, Pargs&... args); + template * = nullptr> inline void zero_adjoints(std::vector& x, Pargs&... args); @@ -75,6 +79,23 @@ inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args) { zero_adjoints(args...); } +/** + * Zero the adjoints of the varis of every var in a closure. + * Recursively call zero_adjoints on the rest of the arguments. + * + * @tparam F type of current argument + * @tparam Pargs type of rest of arguments + * + * @param f current argument + * @param args rest of arguments to zero + */ +template *, + require_not_st_arithmetic*> +inline void zero_adjoints(F& f, Pargs&... args) { + f.zero_adjoints__(); + zero_adjoints(args...); +} + /** * Zero the adjoints of every element in a vector. Recursively call * zero_adjoints on the rest of the arguments. diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 0cba70a321e..8abcf660173 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -15,7 +15,7 @@ namespace math { */ template -std::vector>> +std::vector>> integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -29,7 +29,7 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index c3877bdb875..03bcda04234 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -15,7 +15,7 @@ namespace math { */ template -std::vector>> +std::vector>> integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -29,7 +29,7 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index ee0bdafbbd5..c89663126f6 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -45,7 +45,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... T_Args, require_eigen_col_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, @@ -65,6 +66,22 @@ ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, args_ref_tuple); } +template * = nullptr, + require_stan_closure_t* = nullptr> +std::vector, + Eigen::Dynamic, 1>> +ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, + const T_t0& t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const T_Args&... args) { + internal::ode_closure_adapter f_adapter; + return ode_adams_tol_impl(function_name, f_adapter, y0, t0, ts, + relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Adams-Moulton solver from diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index a07af2e3339..5a5205ab480 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -46,7 +46,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... T_Args, require_eigen_col_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, @@ -66,6 +67,22 @@ ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, args_ref_tuple); } +template * = nullptr, + require_stan_closure_t* = nullptr> +std::vector, + Eigen::Dynamic, 1>> +ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, + const T_t0& t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const T_Args&... args) { + internal::ode_closure_adapter f_adapter; + return ode_bdf_tol_impl(function_name, f_adapter, y0, t0, ts, + relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the stiff backward differentiation formula diff --git a/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp b/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp index 1847bc7bb8d..20e2c88840f 100644 --- a/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp +++ b/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp @@ -271,3 +271,67 @@ TEST(StanAgradRevOde_integrate_ode_rk45, t0_as_param_AD) { 1e-10, 1e6); test_ad(); } + +TEST(StanAgradRevOde_integrate_ode_rk45, closure) { + using stan::math::integrate_ode_rk45; + using stan::math::to_var; + using stan::math::value_of; + using stan::math::var; + const double t0 = 0.0; + std::ostream* msgs = NULL; + + stan::math::var a0 = 0.0; + auto ode = from_lambda( + [](const auto& a, const auto& t_in, const auto& y_in, const auto& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs) { + if (y_in.size() != 2) + throw std::domain_error( + "this function was called with inconsistent state"); + + std::vector> res; + res.push_back(y_in.at(1)); + res.push_back(-y_in.at(0) - theta.at(0) * y_in.at(1)); + + return res; + }, + a0); + + std::vector theta{0.15}; + std::vector y0{1.0, 0.0}; + std::vector ts = {5.0, 10.0}; + + std::vector x; + std::vector x_int; + std::vector y0v = to_var(y0); + std::vector thetav = to_var(theta); + stan::math::var t0v = to_var(t0); + + std::vector> res; + auto test_ad = [&res, &t0v, &a0, &theta, &x, &x_int, &msgs]() { + res[0][0].grad(); + EXPECT_FLOAT_EQ(t0v.adj(), -0.66360742442816977871); + stan::math::set_zero_all_adjoints(); + res[0][1].grad(); + EXPECT_FLOAT_EQ(t0v.adj(), 0.23542843380353062344); + stan::math::set_zero_all_adjoints(); + res[1][0].grad(); + EXPECT_FLOAT_EQ(t0v.adj(), -0.2464078910913158893); + stan::math::set_zero_all_adjoints(); + res[1][1].grad(); + EXPECT_FLOAT_EQ(t0v.adj(), -0.38494826636037426937); + stan::math::set_zero_all_adjoints(); + }; + res = integrate_ode_rk45(ode, y0, t0v, ts, theta, x, x_int, nullptr, 1e-10, + 1e-10, 1e6); + test_ad(); + res = integrate_ode_rk45(ode, y0v, t0v, ts, theta, x, x_int, nullptr, 1e-10, + 1e-10, 1e6); + test_ad(); + res = integrate_ode_rk45(ode, y0, t0v, ts, thetav, x, x_int, nullptr, 1e-10, + 1e-10, 1e6); + test_ad(); + res = integrate_ode_rk45(ode, y0v, t0v, ts, thetav, x, x_int, nullptr, 1e-10, + 1e-10, 1e6); + test_ad(); +} diff --git a/test/unit/math/rev/functor/ode_rk45_rev_test.cpp b/test/unit/math/rev/functor/ode_rk45_rev_test.cpp index a1665e53e35..6d021038b06 100644 --- a/test/unit/math/rev/functor/ode_rk45_rev_test.cpp +++ b/test/unit/math/rev/functor/ode_rk45_rev_test.cpp @@ -265,6 +265,86 @@ TEST(StanMathOde_ode_rk45, scalar_std_vector_args) { EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); } +TEST(StanMathOde_ode_rk45, closure_var) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [&](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + var output = stan::math::ode_rk45(f, y0, t0, ts, nullptr, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); + EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); +} + +TEST(StanMathOde_ode_rk45, closure_double) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + var output = stan::math::ode_rk45(f, y0, t0, ts, nullptr, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); +} + +TEST(StanMathOde_ode_rk45, higher_order) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + auto wrapper + = [](const auto& t, const auto& y, std::ostream* msgs, const auto& fa, + const auto& b) { return fa(msgs, t, y, b); }; + + var output = stan::math::ode_rk45(wrapper, y0, t0, ts, nullptr, f, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); +} + TEST(StanMathOde_ode_rk45, std_vector_std_vector_args) { using stan::math::var; diff --git a/test/unit/math/rev/functor/reduce_sum_closure_test.cpp b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp new file mode 100644 index 00000000000..4174a238b84 --- /dev/null +++ b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +TEST(StanMathRev_reduce_sum, grouped_gradient_closure) { + using stan::math::from_lambda; + using stan::math::reduce_sum_closure_adapter; + using stan::math::var; + using stan::math::test::get_new_msg; + + double lambda_d = 10.0; + const std::size_t groups = 10; + const std::size_t elems_per_group = 1000; + const std::size_t elems = groups * elems_per_group; + + std::vector data(elems); + std::vector gidx(elems); + + for (std::size_t i = 0; i != elems; ++i) { + data[i] = i; + gidx[i] = i / elems_per_group; + } + + std::vector vlambda_v; + + for (std::size_t i = 0; i != groups; ++i) + vlambda_v.push_back(i + 0.2); + + var lambda_v = vlambda_v[0]; + + auto functor = from_lambda( + [](auto& lambda, auto& slice, std::size_t start, std::size_t end, + auto& gidx, std::ostream* msgs) { + const std::size_t num_terms = end - start + 1; + std::decay_t lambda_slice(num_terms); + for (std::size_t i = 0; i != num_terms; ++i) + lambda_slice[i] = lambda[gidx[start + i]]; + return stan::math::poisson_lpmf(slice, lambda_slice); + }, + vlambda_v); + + var poisson_lpdf = stan::math::reduce_sum( + data, 5, get_new_msg(), functor, gidx); + + std::vector vref_lambda_v; + for (std::size_t i = 0; i != elems; ++i) { + vref_lambda_v.push_back(vlambda_v[gidx[i]]); + } + var lambda_ref = vlambda_v[0]; + var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v); + + EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref)); + + stan::math::grad(poisson_lpdf_ref.vi_); + const double lambda_ref_adj = lambda_ref.adj(); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf.vi_); + const double lambda_adj = lambda_v.adj(); + + EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj) + << "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << std::endl + << "ref gradient wrt to lambda: " << lambda_ref_adj << std::endl + << "value of poisson lpdf : " << poisson_lpdf.val() << std::endl + << "gradient wrt to lambda: " << lambda_adj << std::endl; + + var poisson_lpdf_static + = stan::math::reduce_sum_static( + data, 5, get_new_msg(), functor, gidx); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf_static.vi_); + const double lambda_adj_static = lambda_v.adj(); + EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj); + stan::math::recover_memory(); + + stan::math::recover_memory(); +} From 7382327ae8eac5e38e17237e56f095992a4ca63c Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Sun, 21 Feb 2021 17:03:13 +0200 Subject: [PATCH 02/25] some docs --- stan/math/prim/functor/closure_adapter.hpp | 6 ++++++ stan/math/prim/meta/is_stan_closure.hpp | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 378b5390395..b620d20e220 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -63,11 +63,17 @@ struct one_arg_closure { } }; +/** + * Create a closure object from a callable. + */ template auto from_lambda(F f) { return empty_closure(f); } +/** + * Create a closure that captures a single argument. + */ template auto from_lambda(F f, T a) { return one_arg_closure(f, a); diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp index 82860278df6..30f4db0861a 100644 --- a/stan/math/prim/meta/is_stan_closure.hpp +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -58,16 +58,25 @@ template struct capture_type { using type = const T&; }; + template struct capture_type>> { using type = std::remove_reference_t; }; + template struct capture_type>> { using type = typename std::remove_reference_t::CopyOf__; }; + +/** + * Type for things captured either by const reference or by copy. + * + * @tparam T type of object being captured + * @tparam Ref true if reference, false if copy + */ template using capture_type_t = typename capture_type::type; From ee2160010a1bb70b392303d40b6ecb51e1c20e1f Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Sun, 21 Feb 2021 21:02:11 +0200 Subject: [PATCH 03/25] fix headers --- stan/math/prim/functor/closure_adapter.hpp | 1 + stan/math/prim/functor/ode_rk45.hpp | 1 + stan/math/rev/functor/ode_adams.hpp | 1 + stan/math/rev/functor/ode_bdf.hpp | 5 +++-- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index b620d20e220..77c32077995 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP #define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP +#include #include #include diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index 68eeb00043d..b60a18ac3a5 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index c89663126f6..9ebeb026f94 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index 5a5205ab480..8fee90cfdbd 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -79,8 +80,8 @@ ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, std::ostream* msgs, const T_Args&... args) { internal::ode_closure_adapter f_adapter; return ode_bdf_tol_impl(function_name, f_adapter, y0, t0, ts, - relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f, args...); + relative_tolerance, absolute_tolerance, max_num_steps, + msgs, f, args...); } /** From a595b43b0512beee06c75ceb6c57dfa3c8743665 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 23 Feb 2021 13:03:42 +0200 Subject: [PATCH 04/25] add from_lambda() for suffix functions, and fix integrate_ode adapter --- stan/math/prim/functor/closure_adapter.hpp | 120 +++++++++++++++--- ...grate_ode_std_vector_interface_adapter.hpp | 26 ++-- 2 files changed, 117 insertions(+), 29 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 77c32077995..0035aa949cb 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -13,7 +13,6 @@ struct empty_closure { using captured_scalar_t__ = double; using ValueOf__ = empty_closure; using CopyOf__ = empty_closure; - static const size_t vars_count__ = 0; F f_; explicit empty_closure(const F& f) : f_(f) {} @@ -23,9 +22,9 @@ struct empty_closure { return f_(args..., msgs); } size_t count_vars__() const { return 0; } - auto value_of__() const { return empty_closure(f_); } - auto shallow_copy__() const { return empty_closure(f_); } - auto deep_copy_vars__() const { return empty_closure(f_); } + auto value_of__() const { return ValueOf__(f_); } + auto copy_of__() const { return CopyOf__(f_); } + auto deep_copy_vars__() const { return CopyOf__(f_); } void zero_adjoints__() const {} double* accumulate_adjoints__(double* dest) const { return dest; } template @@ -34,13 +33,14 @@ struct empty_closure { } }; -template +template struct one_arg_closure { using captured_scalar_t__ = return_type_t; - using ValueOf__ = one_arg_closure()))>; - using CopyOf__ = one_arg_closure; + using ValueOf__ + = one_arg_closure()))>; + using CopyOf__ = one_arg_closure; F f_; - T s_; + capture_type_t s_; explicit one_arg_closure(const F& f, const T& s) : f_(f), s_(s) {} @@ -50,10 +50,8 @@ struct one_arg_closure { } size_t count_vars__() const { return count_vars(s_); } auto value_of__() const { return ValueOf__(f_, value_of(s_)); } - auto shallow_copy__() const { return one_arg_closure(f_, s_); } - auto deep_copy_vars__() const { - return one_arg_closure(f_, deep_copy_vars(s_)); - } + auto copy_of__() const { return CopyOf__(f_, s_); } + auto deep_copy_vars__() const { return CopyOf__(f_, deep_copy_vars(s_)); } void zero_adjoints__() { zero_adjoints(s_); } double* accumulate_adjoints__(double* dest) const { return accumulate_adjoints(dest, s_); @@ -64,11 +62,88 @@ struct one_arg_closure { } }; +template +struct empty_closure_rng { + using captured_scalar_t__ = double; + using ValueOf__ = empty_closure_rng; + using CopyOf__ = empty_closure_rng; + F f_; + + explicit empty_closure_rng(const F& f) : f_(f) {} + + template + auto operator()(const Rng& rng, std::ostream* msgs, Args... args) const { + return f_(args..., rng, msgs); + } + size_t count_vars__() const { return 0; } + auto value_of__() const { return ValueOf__(f_); } + auto copy_of__() const { return CopyOf__(f_); } + auto deep_copy_vars__() const { return CopyOf__(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +template +struct empty_closure_lpdf { + using captured_scalar_t__ = double; + using ValueOf__ = empty_closure_lpdf; + using CopyOf__ = empty_closure_lpdf; + F f_; + + explicit empty_closure_lpdf(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_.template operator()(args..., msgs); + } + size_t count_vars__() const { return 0; } + auto value_of__() const { return ValueOf__(f_); } + auto copy_of__() const { return CopyOf__(f_); } + auto deep_copy_vars__() const { return CopyOf__(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +template +struct empty_closure_lp { + using captured_scalar_t__ = double; + using ValueOf__ = empty_closure_lp; + using CopyOf__ = empty_closure_lp; + static const size_t vars_count__ = 0; + F f_; + + explicit empty_closure_lp(const F& f) : f_(f) {} + + template + auto operator()(T_lp_accum& lp, T_lp& lp_accum, std::ostream* msgs, + Args... args) const { + return f_(args..., lp, lp_accum, msgs); + } + size_t count_vars__() const { return 0; } + auto value_of__() const { return ValueOf__(f_); } + auto copy_of__() const { return CopyOf__(f_); } + auto deep_copy_vars__() const { return CopyOf__(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + /** * Create a closure object from a callable. */ template -auto from_lambda(F f) { +auto from_lambda(const F& f) { return empty_closure(f); } @@ -76,8 +151,23 @@ auto from_lambda(F f) { * Create a closure that captures a single argument. */ template -auto from_lambda(F f, T a) { - return one_arg_closure(f, a); +auto from_lambda(const F& f, const T& a) { + return one_arg_closure(f, a); +} + +template +auto rng_from_lambda(const F& f) { + return empty_closure_rng(f); +} + +template +auto lpdf_from_lambda(const F& f) { + return empty_closure_lpdf(f); +} + +template +auto lp_from_lambda(const F& f) { + return empty_closure_lp(f); } template diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 45011299553..a50794a1680 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -20,11 +20,11 @@ namespace internal { * state as an Eigen::Matrix. The adapter converts to and from these forms * so that the old ODE interfaces can work. */ -template +template struct integrate_ode_std_vector_interface_adapter_impl; -template -struct integrate_ode_std_vector_interface_adapter_impl { +template +struct integrate_ode_std_vector_interface_adapter_impl { const F& f_; explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) : f_(f) {} @@ -38,13 +38,14 @@ struct integrate_ode_std_vector_interface_adapter_impl { } }; -template -struct integrate_ode_std_vector_interface_adapter_impl { +template +struct integrate_ode_std_vector_interface_adapter_impl { using captured_scalar_t__ = typename F::captured_scalar_t__; - using ValueOf__ - = integrate_ode_std_vector_interface_adapter_impl; - F f_; + using ValueOf__ = integrate_ode_std_vector_interface_adapter_impl< + true, typename F::ValueOf__, false>; + using CopyOf__ = integrate_ode_std_vector_interface_adapter_impl< + true, typename F::CopyOf__, false>; + capture_type_t f_; explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) : f_(f) {} @@ -59,10 +60,7 @@ struct integrate_ode_std_vector_interface_adapter_impl { size_t count_vars__() const { return f_.count_vars__(); } auto value_of__() const { return ValueOf__(f_.value_of__()); } - auto deep_copy_vars__() const { - return integrate_ode_std_vector_interface_adapter_impl( - f_.deep_copy_vars__()); - } + auto deep_copy_vars__() const { return CopyOf__(f_.deep_copy_vars__()); } void zero_adjoints__() { f_.zero_adjoints__(); } double* accumulate_adjoints__(double* dest) const { return f_.accumulate_adjoints__(dest); @@ -76,7 +74,7 @@ struct integrate_ode_std_vector_interface_adapter_impl { template using integrate_ode_std_vector_interface_adapter = integrate_ode_std_vector_interface_adapter_impl::value, - F>; + F, true>; } // namespace internal From 29c165f00257dbc4a789cacf1754b6c7832268bb Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Sat, 6 Mar 2021 12:27:07 +0200 Subject: [PATCH 05/25] add some minimal docs --- stan/math/prim/functor/closure_adapter.hpp | 66 ++++++++++++++----- stan/math/rev/core/deep_copy_vars.hpp | 4 +- .../functor/integrate_ode_rk45_rev_test.cpp | 6 +- 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 0035aa949cb..ad2d37ece49 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -7,7 +7,11 @@ namespace stan { namespace math { +namespace internal { +/** + * A closure that wraps a C++ lambda. + */ template struct empty_closure { using captured_scalar_t__ = double; @@ -33,6 +37,9 @@ struct empty_closure { } }; +/** + * A closure that holds one autodiffable capture. + */ template struct one_arg_closure { using captured_scalar_t__ = return_type_t; @@ -62,6 +69,9 @@ struct one_arg_closure { } }; +/** + * A closure that takes rng argument. + */ template struct empty_closure_rng { using captured_scalar_t__ = double; @@ -87,6 +97,9 @@ struct empty_closure_rng { } }; +/** + * A closure that can be called with `propto` template argument. + */ template struct empty_closure_lpdf { using captured_scalar_t__ = double; @@ -112,6 +125,9 @@ struct empty_closure_lpdf { } }; +/** + * A closure that accesses logprob accumulator. + */ template struct empty_closure_lp { using captured_scalar_t__ = double; @@ -139,12 +155,25 @@ struct empty_closure_lp { } }; +/** + * Higher-order functor suitable for calling a closure inside variadic ODE solvers. + */ +struct ode_closure_adapter { + template + auto operator()(const T0& t, const Eigen::Matrix& y, + std::ostream* msgs, const F& f, Args... args) const { + return f(msgs, t, y, args...); + } +}; + +} // namespace internal + /** * Create a closure object from a callable. */ template auto from_lambda(const F& f) { - return empty_closure(f); + return internal::empty_closure(f); } /** @@ -152,24 +181,36 @@ auto from_lambda(const F& f) { */ template auto from_lambda(const F& f, const T& a) { - return one_arg_closure(f, a); + return internal::one_arg_closure(f, a); } +/** + * Create a closure from an rng functor. + */ template auto rng_from_lambda(const F& f) { - return empty_closure_rng(f); + return internal::empty_closure_rng(f); } +/** + * Create a closure from an lpdf functor. + */ template auto lpdf_from_lambda(const F& f) { - return empty_closure_lpdf(f); + return internal::empty_closure_lpdf(f); } +/** + * Create a closure from a functor that needs access to logprob accumulator. + */ template auto lp_from_lambda(const F& f) { - return empty_closure_lp(f); + return internal::empty_closure_lp(f); } +/** + * A wrapper that sets propto template argument when calling the inner closure. + */ template struct lpdf_wrapper { using captured_scalar_t__ = return_type_t; @@ -204,6 +245,9 @@ struct lpdf_wrapper { } }; +/** + * Higher-order functor that invokes a closure inside a reduce_sum call. + */ struct reduce_sum_closure_adapter { template auto operator()(const std::vector& sub_slice, std::size_t start, @@ -213,18 +257,6 @@ struct reduce_sum_closure_adapter { } }; -namespace internal { - -struct ode_closure_adapter { - template - auto operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const F& f, Args... args) const { - return f(msgs, t, y, args...); - } -}; - -} // namespace internal - } // namespace math } // namespace stan diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index cf2e7479da5..047bb62ef73 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -85,8 +85,8 @@ inline auto deep_copy_vars(EigT&& arg) { * Copy the vars in f but reallocate new varis for them * * @tparam F A closure type - * @param f A closure of vars - * @return A new std::vector of vars + * @param f A closure containing vars + * @return A new closure containing vars */ template * = nullptr, require_not_arithmetic_t>* = nullptr> diff --git a/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp b/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp index 20e2c88840f..af4f8006d2d 100644 --- a/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp +++ b/test/unit/math/rev/functor/integrate_ode_rk45_rev_test.cpp @@ -291,7 +291,7 @@ TEST(StanAgradRevOde_integrate_ode_rk45, closure) { std::vector> res; res.push_back(y_in.at(1)); - res.push_back(-y_in.at(0) - theta.at(0) * y_in.at(1)); + res.push_back(-y_in.at(0) - (a + theta.at(0)) * y_in.at(1)); return res; }, @@ -311,15 +311,19 @@ TEST(StanAgradRevOde_integrate_ode_rk45, closure) { auto test_ad = [&res, &t0v, &a0, &theta, &x, &x_int, &msgs]() { res[0][0].grad(); EXPECT_FLOAT_EQ(t0v.adj(), -0.66360742442816977871); + EXPECT_FLOAT_EQ(a0.adj(), -0.80045092); stan::math::set_zero_all_adjoints(); res[0][1].grad(); EXPECT_FLOAT_EQ(t0v.adj(), 0.23542843380353062344); + EXPECT_FLOAT_EQ(a0.adj(), -1.5989847); stan::math::set_zero_all_adjoints(); res[1][0].grad(); EXPECT_FLOAT_EQ(t0v.adj(), -0.2464078910913158893); + EXPECT_FLOAT_EQ(a0.adj(), 1.904654); stan::math::set_zero_all_adjoints(); res[1][1].grad(); EXPECT_FLOAT_EQ(t0v.adj(), -0.38494826636037426937); + EXPECT_FLOAT_EQ(a0.adj(), -1.3748885); stan::math::set_zero_all_adjoints(); }; res = integrate_ode_rk45(ode, y0, t0v, ts, theta, x, x_int, nullptr, 1e-10, From bbabc920ad78d258998bd4a4fac1fbe021491c8e Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Sat, 20 Mar 2021 14:26:25 +0200 Subject: [PATCH 06/25] fix reduce_sum off-by-one index --- stan/math/prim/functor/closure_adapter.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index ad2d37ece49..d5dad08534b 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP #define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP +#include #include #include #include @@ -156,7 +157,8 @@ struct empty_closure_lp { }; /** - * Higher-order functor suitable for calling a closure inside variadic ODE solvers. + * Higher-order functor suitable for calling a closure inside variadic ODE + * solvers. */ struct ode_closure_adapter { template @@ -253,7 +255,8 @@ struct reduce_sum_closure_adapter { auto operator()(const std::vector& sub_slice, std::size_t start, std::size_t end, std::ostream* msgs, const F& f, Args... args) const { - return f(msgs, sub_slice, start, end, args...); + return f(msgs, sub_slice, start + error_index::value, + end + error_index::value, args...); } }; From 76a8991e27558f6bad5267488ef2b9b0ed58bff5 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Sat, 20 Mar 2021 15:19:44 +0200 Subject: [PATCH 07/25] closure support for integrate_1d --- stan/math/prim/functor/integrate_1d.hpp | 14 +++++++++++++- stan/math/prim/functor/integrate_1d_adapter.hpp | 15 +++++++++++++++ stan/math/rev/functor/integrate_1d.hpp | 12 ++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/stan/math/prim/functor/integrate_1d.hpp b/stan/math/prim/functor/integrate_1d.hpp index 14cbf015be8..97c6415c616 100644 --- a/stan/math/prim/functor/integrate_1d.hpp +++ b/stan/math/prim/functor/integrate_1d.hpp @@ -211,7 +211,7 @@ inline double integrate_1d_impl(const F& f, double a, double b, * @param relative_tolerance tolerance passed to Boost quadrature * @return numeric integral of function f */ -template +template * = nullptr> inline double integrate_1d(const F& f, double a, double b, const std::vector& theta, const std::vector& x_r, @@ -222,6 +222,18 @@ inline double integrate_1d(const F& f, double a, double b, msgs, theta, x_r, x_i); } +template * = nullptr, + require_arithmetic_t>* = nullptr> +inline double integrate_1d(const F& f, double a, double b, + const std::vector& theta, + const std::vector& x_r, + const std::vector& x_i, std::ostream* msgs, + const double relative_tolerance + = std::sqrt(EPSILON)) { + return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, + relative_tolerance, msgs, f, theta, x_r, x_i); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor/integrate_1d_adapter.hpp b/stan/math/prim/functor/integrate_1d_adapter.hpp index ecfcaaaa9a6..60a7ebb3fc4 100644 --- a/stan/math/prim/functor/integrate_1d_adapter.hpp +++ b/stan/math/prim/functor/integrate_1d_adapter.hpp @@ -25,4 +25,19 @@ struct integrate_1d_adapter { } }; +/** + * Call a closure object from integrate_1d + */ +struct integrate_1d_closure_adapter { + explicit integrate_1d_closure_adapter() {} + + template + auto operator()(const T_a& x, const T_b& xc, std::ostream* msgs, const F& f, + const std::vector& theta, + const std::vector& x_r, + const std::vector& x_i) const { + return f(msgs, x, xc, theta, x_r, x_i); + } +}; + #endif diff --git a/stan/math/rev/functor/integrate_1d.hpp b/stan/math/rev/functor/integrate_1d.hpp index e0568ea8ed3..965f69e7b9d 100644 --- a/stan/math/rev/functor/integrate_1d.hpp +++ b/stan/math/rev/functor/integrate_1d.hpp @@ -221,6 +221,7 @@ inline return_type_t integrate_1d_impl( * @return numeric integral of function f */ template , typename = require_any_var_t> inline return_type_t integrate_1d( const F &f, const T_a &a, const T_b &b, const std::vector &theta, @@ -230,6 +231,17 @@ inline return_type_t integrate_1d( msgs, theta, x_r, x_i); } +template , + typename = require_any_var_t, T_a, T_b, T_theta>> +inline return_type_t integrate_1d( + const F &f, const T_a &a, const T_b &b, const std::vector &theta, + const std::vector &x_r, const std::vector &x_i, + std::ostream *msgs, const double relative_tolerance = std::sqrt(EPSILON)) { + return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, + relative_tolerance, msgs, f, theta, x_r, x_i); +} + } // namespace math } // namespace stan From 990c07095198236e1b6db6689cbab61759c02e4e Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 15 Jun 2021 12:35:12 +0300 Subject: [PATCH 08/25] move ode_closure_adapter out of internal namespace --- stan/math/prim/functor/closure_adapter.hpp | 4 ++-- stan/math/prim/functor/integrate_ode_rk45.hpp | 4 ++-- ...grate_ode_std_vector_interface_adapter.hpp | 21 +++++++++++++++++-- stan/math/prim/functor/ode_ckrk.hpp | 16 -------------- stan/math/prim/functor/ode_rk45.hpp | 16 -------------- stan/math/rev/functor/integrate_ode_adams.hpp | 4 ++-- stan/math/rev/functor/integrate_ode_bdf.hpp | 4 ++-- stan/math/rev/functor/ode_adams.hpp | 16 -------------- stan/math/rev/functor/ode_bdf.hpp | 16 -------------- 9 files changed, 27 insertions(+), 74 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index d5dad08534b..544b82ab7f6 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -156,6 +156,8 @@ struct empty_closure_lp { } }; +} // namespace internal + /** * Higher-order functor suitable for calling a closure inside variadic ODE * solvers. @@ -168,8 +170,6 @@ struct ode_closure_adapter { } }; -} // namespace internal - /** * Create a closure object from a callable. */ diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index afc3fdbb1cb..31b691d1cf3 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -22,9 +22,9 @@ inline auto integrate_ode_rk45( std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_rk45_tol_impl("integrate_ode_rk45", f_adapted, to_vector(y0), t0, + auto y = ode_rk45_tol_impl("integrate_ode_rk45", ode_closure_adapter(), to_vector(y0), t0, ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, theta, x, x_int); + max_num_steps, msgs, f_adapted, theta, x, x_int); std::vector>> y_converted; diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index a50794a1680..377de796dce 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -25,17 +25,34 @@ struct integrate_ode_std_vector_interface_adapter_impl; template struct integrate_ode_std_vector_interface_adapter_impl { + using captured_scalar_t__ = double; + using ValueOf__ = integrate_ode_std_vector_interface_adapter_impl< + false, F, false>; + using CopyOf__ = integrate_ode_std_vector_interface_adapter_impl< + false, F, false>; const F& f_; explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) : f_(f) {} template - auto operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const std::vector& theta, + auto operator()(std::ostream* msgs, const T0& t, + const Eigen::Matrix& y, + const std::vector& theta, const std::vector& x, const std::vector& x_int) const { return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs)); } + size_t count_vars__() const { } + auto value_of__() const { return ValueOf__(f_); } + auto deep_copy_vars__() const { return CopyOf__(f_); } + void zero_adjoints__() { } + double* accumulate_adjoints__(double* dest) const { + return dest; + } + template + Vari** save_varis__(Vari** dest) const { + return dest; + } }; template diff --git a/stan/math/prim/functor/ode_ckrk.hpp b/stan/math/prim/functor/ode_ckrk.hpp index dd7062ce8b4..44d455ea520 100644 --- a/stan/math/prim/functor/ode_ckrk.hpp +++ b/stan/math/prim/functor/ode_ckrk.hpp @@ -158,22 +158,6 @@ ode_ckrk_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, return y; } -template * = nullptr, - require_stan_closure_t* = nullptr> -std::vector, - Eigen::Dynamic, 1>> -ode_ckrk_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, - T_t0 t0, const std::vector& ts, - double relative_tolerance, double absolute_tolerance, - long int max_num_steps, // NOLINT(runtime/int) - std::ostream* msgs, const Args&... args) { - internal::ode_closure_adapter f_adapter; - return ode_ckrk_tol_impl(function_name, f_adapter, y0_arg, t0, ts, - relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f, args...); -} - /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using Boost's Cash-Karp solver. diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index b60a18ac3a5..068a267e22d 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -160,22 +160,6 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, return y; } -template * = nullptr, - require_stan_closure_t* = nullptr> -std::vector, - Eigen::Dynamic, 1>> -ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, - T_t0 t0, const std::vector& ts, - double relative_tolerance, double absolute_tolerance, - long int max_num_steps, // NOLINT(runtime/int) - std::ostream* msgs, const Args&... args) { - internal::ode_closure_adapter f_adapter; - return ode_rk45_tol_impl(function_name, f_adapter, y0_arg, t0, ts, - relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f, args...); -} - /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 8abcf660173..ca48b2cbfe2 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -25,9 +25,9 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_adams_tol_impl("integrate_ode_adams", f_adapted, to_vector(y0), + auto y = ode_adams_tol_impl("integrate_ode_adams", ode_closure_adapter(), to_vector(y0), t0, ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, theta, x, x_int); + max_num_steps, msgs, f_adapted, theta, x, x_int); std::vector>> y_converted; diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index 03bcda04234..a7f113e201e 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -25,9 +25,9 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_bdf_tol_impl("integrate_ode_bdf", f_adapted, to_vector(y0), t0, + auto y = ode_bdf_tol_impl("integrate_ode_bdf", ode_closure_adapter(), to_vector(y0), t0, ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, theta, x, x_int); + max_num_steps, msgs, f_adapted, theta, x, x_int); std::vector>> y_converted; diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index 9ebeb026f94..d6b64eeb45a 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -67,22 +67,6 @@ ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, args_ref_tuple); } -template * = nullptr, - require_stan_closure_t* = nullptr> -std::vector, - Eigen::Dynamic, 1>> -ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, - const T_t0& t0, const std::vector& ts, - double relative_tolerance, double absolute_tolerance, - long int max_num_steps, // NOLINT(runtime/int) - std::ostream* msgs, const T_Args&... args) { - internal::ode_closure_adapter f_adapter; - return ode_adams_tol_impl(function_name, f_adapter, y0, t0, ts, - relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f, args...); -} - /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Adams-Moulton solver from diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index 8fee90cfdbd..cf53b1f54e9 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -68,22 +68,6 @@ ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, args_ref_tuple); } -template * = nullptr, - require_stan_closure_t* = nullptr> -std::vector, - Eigen::Dynamic, 1>> -ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, - const T_t0& t0, const std::vector& ts, - double relative_tolerance, double absolute_tolerance, - long int max_num_steps, // NOLINT(runtime/int) - std::ostream* msgs, const T_Args&... args) { - internal::ode_closure_adapter f_adapter; - return ode_bdf_tol_impl(function_name, f_adapter, y0, t0, ts, - relative_tolerance, absolute_tolerance, max_num_steps, - msgs, f, args...); -} - /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the stiff backward differentiation formula From cbf48fadabb5cceb8fa78b1a1d42e9fd6607bb43 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 15 Jun 2021 19:14:40 +0300 Subject: [PATCH 09/25] refactor tests --- stan/math/prim/functor/closure_adapter.hpp | 6 +- .../rev/functor/closure_ode_typed_test.cpp | 80 ++ .../math/rev/functor/cos_ode_typed_test.cpp | 1 - ..._ode_std_vector_interface_adapter_test.cpp | 6 +- .../rev/functor/test_fixture_ode_closure.hpp | 1162 +++++++++++++++++ .../functor/test_fixture_ode_cos_scalar.hpp | 21 - 6 files changed, 1248 insertions(+), 28 deletions(-) create mode 100644 test/unit/math/rev/functor/closure_ode_typed_test.cpp create mode 100644 test/unit/math/rev/functor/test_fixture_ode_closure.hpp diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 544b82ab7f6..2d73223ee7f 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -45,7 +45,7 @@ template struct one_arg_closure { using captured_scalar_t__ = return_type_t; using ValueOf__ - = one_arg_closure()))>; + = one_arg_closure())))>; using CopyOf__ = one_arg_closure; F f_; capture_type_t s_; @@ -57,9 +57,9 @@ struct one_arg_closure { return f_(s_, args..., msgs); } size_t count_vars__() const { return count_vars(s_); } - auto value_of__() const { return ValueOf__(f_, value_of(s_)); } + auto value_of__() const { return ValueOf__(f_, eval(value_of(s_))); } auto copy_of__() const { return CopyOf__(f_, s_); } - auto deep_copy_vars__() const { return CopyOf__(f_, deep_copy_vars(s_)); } + auto deep_copy_vars__() const { return CopyOf__(f_, eval(deep_copy_vars(s_))); } void zero_adjoints__() { zero_adjoints(s_); } double* accumulate_adjoints__(double* dest) const { return accumulate_adjoints(dest, s_); diff --git a/test/unit/math/rev/functor/closure_ode_typed_test.cpp b/test/unit/math/rev/functor/closure_ode_typed_test.cpp new file mode 100644 index 00000000000..dcd4fecc8e1 --- /dev/null +++ b/test/unit/math/rev/functor/closure_ode_typed_test.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include +#include +#include + +/** + * + * Use same solver functor type for both w & w/o tolerance control + */ +template +using ode_test_tuple = std::tuple; + +/** + * Outer product of test types + */ +using closure_test_types = boost::mp11::mp_product< + ode_test_tuple, ::testing::Types >; + +TYPED_TEST_SUITE_P(closure_test); +TYPED_TEST_P(closure_test, y0_error) { + this->test_y0_error(); + this->test_y0_error_with_tol(); +} +TYPED_TEST_P(closure_test, t0_error) { + this->test_t0_error(); + this->test_t0_error_with_tol(); +} +TYPED_TEST_P(closure_test, ts_error) { + this->test_ts_error(); + this->test_ts_error_with_tol(); +} +TYPED_TEST_P(closure_test, two_arg_error) { + this->test_two_arg_error(); + this->test_two_arg_error_with_tol(); +} +TYPED_TEST_P(closure_test, tol_error) { + this->test_rtol_error(); + this->test_atol_error(); + this->test_max_num_step_error(); + this->test_too_much_work(); +} +TYPED_TEST_P(closure_test, value) { this->test_value(); } +TYPED_TEST_P(closure_test, grad) { + this->test_grad_t0(); + this->test_grad_ts(); + this->test_grad_ts_repeat(); + this->test_scalar_arg(); + this->test_std_vector_arg(); + this->test_vector_arg(); + this->test_row_vector_arg(); + this->test_matrix_arg(); + this->test_scalar_std_vector_args(); + this->test_std_vector_std_vector_args(); + this->test_std_vector_vector_args(); + this->test_std_vector_row_vector_args(); + this->test_std_vector_matrix_args(); + this->test_arg_combos_test(); +} +TYPED_TEST_P(closure_test, tol_grad) { + this->test_tol_t0(); + this->test_tol_ts(); + this->test_tol_ts_repeat(); + this->test_tol_scalar_arg(); + this->test_tol_scalar_arg_multi_time(); + this->test_tol_std_vector_arg(); + this->test_tol_vector_arg(); + this->test_tol_row_vector_arg(); + this->test_tol_matrix_arg(); + this->test_tol_scalar_std_vector_args(); + this->test_tol_std_vector_std_vector_args(); + this->test_tol_std_vector_vector_args(); + this->test_tol_std_vector_row_vector_args(); + this->test_tol_std_vector_matrix_args(); +} +REGISTER_TYPED_TEST_SUITE_P(closure_test, y0_error, t0_error, ts_error, + two_arg_error, tol_error, value, grad, tol_grad); +INSTANTIATE_TYPED_TEST_SUITE_P(StanOde, closure_test, closure_test_types); diff --git a/test/unit/math/rev/functor/cos_ode_typed_test.cpp b/test/unit/math/rev/functor/cos_ode_typed_test.cpp index 7b70a40602d..1327745c2e3 100644 --- a/test/unit/math/rev/functor/cos_ode_typed_test.cpp +++ b/test/unit/math/rev/functor/cos_ode_typed_test.cpp @@ -69,7 +69,6 @@ TYPED_TEST_P(cos_arg_test, grad) { this->test_std_vector_vector_args(); this->test_std_vector_row_vector_args(); this->test_std_vector_matrix_args(); - this->test_closure(); this->test_arg_combos_test(); } TYPED_TEST_P(cos_arg_test, tol_grad) { diff --git a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp index f8611f7f57a..effa39ac332 100644 --- a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -22,7 +22,7 @@ TEST(StanMathRev, vd) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(theta.size()); @@ -58,7 +58,7 @@ TEST(StanMathRev, dv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(y.size()); @@ -94,7 +94,7 @@ TEST(StanMathRev, vv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs_theta_1(theta.size()); diff --git a/test/unit/math/rev/functor/test_fixture_ode_closure.hpp b/test/unit/math/rev/functor/test_fixture_ode_closure.hpp new file mode 100644 index 00000000000..8dac5742d28 --- /dev/null +++ b/test/unit/math/rev/functor/test_fixture_ode_closure.hpp @@ -0,0 +1,1162 @@ +#ifndef STAN_MATH_TEST_FIXTURE_ODE_CLOSURE_HPP +#define STAN_MATH_TEST_FIXTURE_ODE_CLOSURE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct closure_ode_base { + Eigen::VectorXd y0; + double t0; + std::vector ts; + double a; + double rtol; + double atol; + int max_num_step; + + closure_ode_base() + : y0(1), + t0(0.0), + ts{0.45, 1.1}, + a(1.5), + rtol(1.e-10), + atol(1.e-10), + max_num_step(100000) { + y0[0] = 0.0; + } +}; + +/** + * Inheriting base type, various fixtures differs by the type of ODE + * functor used in apply_solver calls, intended for + * different kind of tests. + * + */ +template +struct closure_test : public closure_ode_base, + public ODETestFixture> { + closure_test() : closure_ode_base() {} + + Eigen::VectorXd init() { return y0; } + std::vector param() { return {a}; } + + auto apply_solver() { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, + stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, a)); + } + + template + auto apply_solver(Eigen::Matrix& init, std::vector& va) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), init, t0, ts, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, va)); + } + + template + auto apply_solver_ts(const std::vector& ts_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, nullptr, + stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, a)); + } + + template + auto apply_solver_ts(const std::vector& ts_, const a_type& arg) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, nullptr, + stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, arg)); + } + + template + auto apply_solver_ts_tol(const std::vector& ts_, double rtol, double atol, int max_num_steps, const a_type& a_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, rtol, atol, max_num_steps, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, a_)); + } + + template + auto apply_solver_t0(const T0& t0_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, nullptr, + stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, a)); + } + + template + auto apply_solver_t0_tol(const T0& t0_, double rtol, double atol, int max_num_steps, const a_type& a_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, rtol, atol, max_num_steps, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, a_)); + } + + auto apply_solver_tol() { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, max_num_step, + nullptr, + stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, a)); + } + + template + auto apply_solver_arg(a_type const& a_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, a_)); + } + + template + auto apply_solver_arg_tol(a_type const& a_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, max_num_step, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, a_)); + } + + template + auto apply_solver_arg(a_type const& a_, b_type const& b_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, a_), b_); + } + + template + auto apply_solver_arg_tol(a_type const& a_, b_type const& b_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, max_num_step, nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, a_), b_); + } + + void test_y0_error() { + y0 = Eigen::VectorXd::Zero(1); + ASSERT_NO_THROW(apply_solver()); + + y0[0] = stan::math::INFTY; + EXPECT_THROW(apply_solver(), std::domain_error); + + y0[0] = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver(), std::domain_error); + + y0 = Eigen::VectorXd(); + EXPECT_THROW(apply_solver(), std::invalid_argument); + } + + void test_y0_error_with_tol() { + y0 = Eigen::VectorXd::Zero(1); + ASSERT_NO_THROW(apply_solver_tol()); + + y0[0] = stan::math::INFTY; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + y0[0] = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + y0 = Eigen::VectorXd(); + EXPECT_THROW(apply_solver_tol(), std::invalid_argument); + } + + void test_t0_error() { + t0 = stan::math::INFTY; + EXPECT_THROW(apply_solver(), std::domain_error); + + t0 = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver(), std::domain_error); + } + + void test_t0_error_with_tol() { + t0 = stan::math::INFTY; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + t0 = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_ts_error() { + std::vector ts_repeat = {0.45, 0.45}; + std::vector ts_lots = {0.45, 0.45, 1.1, 1.1, 2.0}; + std::vector ts_empty = {}; + std::vector ts_early = {-0.45, 0.2}; + std::vector ts_decreasing = {0.45, 0.2}; + std::vector tsinf = {stan::math::INFTY, 1.1}; + std::vector tsNaN = {0.45, stan::math::NOT_A_NUMBER}; + + std::vector out; + EXPECT_NO_THROW(out = apply_solver()); + EXPECT_EQ(out.size(), ts.size()); + + ts = ts_repeat; + EXPECT_NO_THROW(out = apply_solver()); + EXPECT_EQ(out.size(), ts_repeat.size()); + EXPECT_MATRIX_FLOAT_EQ(out[0], out[1]); + + ts = ts_lots; + EXPECT_NO_THROW(out = apply_solver()); + EXPECT_EQ(out.size(), ts_lots.size()); + + ts = ts_empty; + EXPECT_THROW(apply_solver(), std::invalid_argument); + + ts = ts_early; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = ts_decreasing; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = tsinf; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = tsNaN; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = {0.45, 1.1}; + } + + void test_ts_error_with_tol() { + std::vector ts_repeat = {0.45, 0.45}; + std::vector ts_lots = {0.45, 0.45, 1.1, 1.1, 2.0}; + std::vector ts_empty = {}; + std::vector ts_early = {-0.45, 0.2}; + std::vector ts_decreasing = {0.45, 0.2}; + std::vector tsinf = {stan::math::INFTY, 1.1}; + std::vector tsNaN = {0.45, stan::math::NOT_A_NUMBER}; + + std::vector out; + EXPECT_NO_THROW(out = apply_solver_tol()); + EXPECT_EQ(out.size(), ts.size()); + + ts = ts_repeat; + EXPECT_NO_THROW(out = apply_solver_tol()); + EXPECT_EQ(out.size(), ts_repeat.size()); + EXPECT_MATRIX_FLOAT_EQ(out[0], out[1]); + + ts = ts_lots; + EXPECT_NO_THROW(out = apply_solver_tol()); + EXPECT_EQ(out.size(), ts_lots.size()); + + ts = ts_empty; + EXPECT_THROW(apply_solver_tol(), std::invalid_argument); + + ts = ts_early; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = ts_decreasing; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = tsinf; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = tsNaN; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = {0.45, 1.1}; + } + + void test_two_arg_error() { + a = 1.5; + double ainf = stan::math::INFTY; + double aNaN = stan::math::NOT_A_NUMBER; + + std::vector va = {a}; + std::vector vainf = {ainf}; + std::vector vaNaN = {aNaN}; + + Eigen::VectorXd ea(1); + ea << a; + Eigen::VectorXd eainf(1); + eainf << ainf; + Eigen::VectorXd eaNaN(1); + eaNaN << aNaN; + + std::vector> vva = {va}; + std::vector> vvainf = {vainf}; + std::vector> vvaNaN = {vaNaN}; + + std::vector vea = {ea}; + std::vector veainf = {eainf}; + std::vector veaNaN = {eaNaN}; + + EXPECT_NO_THROW(apply_solver_arg(a, a)); + + EXPECT_THROW(apply_solver_arg(a, ainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, aNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, va)); + + EXPECT_THROW(apply_solver_arg(a, vainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, vaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, ea)); + + EXPECT_THROW(apply_solver_arg(a, eainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, eaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, vva)); + + EXPECT_THROW(apply_solver_arg(a, vvainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, vvaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, vea)); + + EXPECT_THROW(apply_solver_arg(a, veainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, veaNaN), std::domain_error); + } + + void test_two_arg_error_with_tol() { + a = 1.5; + double ainf = stan::math::INFTY; + double aNaN = stan::math::NOT_A_NUMBER; + + std::vector va = {a}; + std::vector vainf = {ainf}; + std::vector vaNaN = {aNaN}; + + Eigen::VectorXd ea(1); + ea << a; + Eigen::VectorXd eainf(1); + eainf << ainf; + Eigen::VectorXd eaNaN(1); + eaNaN << aNaN; + + std::vector> vva = {va}; + std::vector> vvainf = {vainf}; + std::vector> vvaNaN = {vaNaN}; + + std::vector vea = {ea}; + std::vector veainf = {eainf}; + std::vector veaNaN = {eaNaN}; + + EXPECT_NO_THROW(apply_solver_arg_tol(a, a)); + + EXPECT_THROW(apply_solver_arg_tol(a, ainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, aNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, va)); + + EXPECT_THROW(apply_solver_arg_tol(a, vainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, vaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, ea)); + + EXPECT_THROW(apply_solver_arg_tol(a, eainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, eaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, vva)); + + EXPECT_THROW(apply_solver_arg_tol(a, vvainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, vvaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, vea)); + + EXPECT_THROW(apply_solver_arg_tol(a, veainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, veaNaN), std::domain_error); + } + + void test_rtol_error() { + y0 = Eigen::VectorXd::Zero(1); + t0 = 0; + ts = {0.45, 1.1}; + a = 1.5; + + rtol = 1e-6; + atol = 1e-6; + double rtol_negative = -1e-6; + double rtolinf = stan::math::INFTY; + double rtolNaN = stan::math::NOT_A_NUMBER; + + EXPECT_NO_THROW(apply_solver_tol()); + + rtol = rtol_negative; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + rtol = rtolinf; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + rtol = rtolNaN; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_atol_error() { + y0 = Eigen::VectorXd::Zero(1); + t0 = 0; + ts = {0.45, 1.1}; + a = 1.5; + + rtol = 1e-6; + atol = 1e-6; + double atol_negative = -1e-6; + double atolinf = stan::math::INFTY; + double atolNaN = stan::math::NOT_A_NUMBER; + + EXPECT_NO_THROW(apply_solver_tol()); + + atol = atol_negative; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + atol = atolinf; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + atol = atolNaN; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_max_num_step_error() { + rtol = 1e-6; + atol = 1e-6; + max_num_step = 500; + int max_num_steps_negative = -500; + int max_num_steps_zero = 0; + + EXPECT_NO_THROW(apply_solver_tol()); + + max_num_step = max_num_steps_negative; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + max_num_step = max_num_steps_zero; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_too_much_work() { + ts[1] = 1e4; + max_num_step = 10; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "Failed to integrate to next output time"); + } + + void test_value() { + std::vector res = apply_solver(); + EXPECT_NEAR(res[0][0], 0.4165982112, 1e-5); + EXPECT_NEAR(res[1][0], 0.66457668563, 1e-5); + + std::vector ts_i = {1, 2}; + std::tuple_element_t<0, T> sol; + res = apply_solver_ts(ts_i); + EXPECT_NEAR(res[0][0], 0.6649966577, 1e-5); + EXPECT_NEAR(res[1][0], 0.09408000537, 1e-5); + + int t0_i = 0; + res = apply_solver_t0(t0_i); + EXPECT_NEAR(res[0][0], 0.4165982112, 1e-5); + EXPECT_NEAR(res[1][0], 0.66457668563, 1e-5); + } + + void test_grad_t0() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + stan::math::var t0v = 0.0; + auto res = apply_solver_t0(t0v); + + res[0][0].grad(); + + EXPECT_NEAR(res[0][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(t0v.adj(), -1.0, 1e-5); + + nested.set_zero_all_adjoints(); + + res[1][0].grad(); + + EXPECT_NEAR(res[1][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(t0v.adj(), -1.0, 1e-5); + } + + void test_grad_ts() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + std::vector tsv = {0.45, 1.1}; + auto res = apply_solver_ts(tsv); + + res[0][0].grad(); + + EXPECT_NEAR(res[0][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(tsv[0].adj(), 0.78070695113, 1e-5); + nested.set_zero_all_adjoints(); + + res[1][0].grad(); + + EXPECT_NEAR(res[1][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(tsv[1].adj(), -0.0791208888, 1e-5); + } + + void test_grad_ts_repeat() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + std::vector tsv = {0.45, 0.45, 1.1, 1.1}; + auto output = apply_solver_ts(tsv); + + EXPECT_EQ(output.size(), tsv.size()); + + output[0][0].grad(); + + EXPECT_NEAR(output[0][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(tsv[0].adj(), 0.78070695113, 1e-5); + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_NEAR(output[1][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(tsv[1].adj(), 0.78070695113, 1e-5); + nested.set_zero_all_adjoints(); + + output[2][0].grad(); + + EXPECT_NEAR(output[2][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(tsv[2].adj(), -0.0791208888, 1e-5); + nested.set_zero_all_adjoints(); + + output[3][0].grad(); + EXPECT_NEAR(output[3][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(tsv[3].adj(), -0.0791208888, 1e-5); + } + + void test_scalar_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + stan::math::var av = 1.5; + + { + std::vector ts1{1.1}; + auto output = apply_solver_ts(ts1, av)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av.adj(), -0.50107310888, 1e-5); + nested.set_zero_all_adjoints(); + } + + { + auto output = apply_solver_arg(av); + + output[0](0).grad(); + + EXPECT_NEAR(output[0](0).val(), 0.4165982112, 1e-5); + EXPECT_NEAR(av.adj(), -0.04352005542, 1e-5); + nested.set_zero_all_adjoints(); + + output[1](0).grad(); + + EXPECT_NEAR(output[1](0).val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av.adj(), -0.50107310888, 1e-5); + nested.set_zero_all_adjoints(); + } + } + + void test_std_vector_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + std::vector av = {1.5}; + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av[0].adj(), -0.50107310888, 1e-5); + } + + void test_vector_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + Eigen::Matrix av(1); + av << 1.5; + + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av(0).adj(), -0.50107310888, 1e-5); + } + + void test_row_vector_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + Eigen::Matrix av(1); + av << 1.5; + + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av(0).adj(), -0.50107310888, 1e-5); + } + + void test_matrix_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + Eigen::Matrix av(1, 1); + av << 1.5; + + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av(0).adj(), -0.50107310888, 1e-5); + } + + void test_scalar_std_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 0.75; + std::vector a1 = {0.75}; + + var output = apply_solver_arg(a0, a1)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); + EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_std_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + std::vector a1(1, a0); + std::vector> a2(1, a1); + + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0][0].adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0](0).adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_row_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0](0).adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_matrix_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + Eigen::Matrix a1(1, 1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0](0).adj(), -0.50107310888, 1e-5); + } + + void test_arg_combos_test() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + + var t0 = 0.5; + var a = 0.2; + std::vector ts = {1.25}; + Eigen::Matrix y0(1); + y0 << 0.75; + + double t0d = stan::math::value_of(t0); + double ad = stan::math::value_of(a); + std::vector tsd = stan::math::value_of(ts); + Eigen::VectorXd y0d = stan::math::value_of(y0); + + auto check_yT = [&](auto yT) { + EXPECT_NEAR(stan::math::value_of(yT), + y0d(0) * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_t0 = [&](var t0) { + EXPECT_NEAR( + t0.adj(), + ad * t0d * y0d(0) * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_a = [&](var a) { + EXPECT_NEAR(a.adj(), + -0.5 * (tsd[0] * tsd[0] - t0d * t0d) * y0d(0) + * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_ts = [&](std::vector ts) { + EXPECT_NEAR(ts[0].adj(), + -ad * tsd[0] * y0d(0) + * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_y0 = [&](Eigen::Matrix y0) { + EXPECT_NEAR(y0(0).adj(), exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + double yT1 = sol(stan::test::ayt(), y0d, t0d, tsd, nullptr, ad)[0](0); + check_yT(yT1); + + var yT2 = sol(stan::test::ayt(), y0d, t0d, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT2.grad(); + check_yT(yT2); + check_a(a); + + var yT3 = sol(stan::test::ayt(), y0d, t0d, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT3.grad(); + check_yT(yT3); + check_ts(ts); + + var yT4 = sol(stan::test::ayt(), y0d, t0d, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT4.grad(); + check_yT(yT4); + check_ts(ts); + check_a(a); + + var yT5 = sol(stan::test::ayt(), y0d, t0, tsd, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT5.grad(); + check_yT(yT5); + check_t0(t0); + + var yT6 = sol(stan::test::ayt(), y0d, t0, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT6.grad(); + check_yT(yT6); + check_t0(t0); + check_a(a); + + var yT7 = sol(stan::test::ayt(), y0d, t0, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT7.grad(); + check_yT(yT7); + check_t0(t0); + check_ts(ts); + + var yT8 = sol(stan::test::ayt(), y0d, t0, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT8.grad(); + check_yT(yT8); + check_t0(t0); + check_ts(ts); + check_a(a); + + var yT9 = sol(stan::test::ayt(), y0, t0d, tsd, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT9.grad(); + check_yT(yT9); + check_y0(y0); + + var yT10 = sol(stan::test::ayt(), y0, t0d, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT10.grad(); + check_yT(yT10); + check_y0(y0); + check_a(a); + + var yT11 = sol(stan::test::ayt(), y0, t0d, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT11.grad(); + check_yT(yT11); + check_y0(y0); + check_ts(ts); + + var yT12 = sol(stan::test::ayt(), y0, t0d, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT12.grad(); + check_yT(yT12); + check_y0(y0); + check_ts(ts); + check_a(a); + + var yT13 = sol(stan::test::ayt(), y0, t0, tsd, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT13.grad(); + check_yT(yT13); + check_y0(y0); + check_t0(t0); + + var yT14 = sol(stan::test::ayt(), y0, t0, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT14.grad(); + check_yT(yT14); + check_y0(y0); + check_t0(t0); + check_a(a); + + var yT15 = sol(stan::test::ayt(), y0, t0, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT15.grad(); + check_yT(yT15); + check_y0(y0); + check_t0(t0); + check_ts(ts); + + var yT16 = sol(stan::test::ayt(), y0, t0, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT16.grad(); + check_yT(yT16); + check_y0(y0); + check_t0(t0); + check_ts(ts); + check_a(a); + } + + void test_tol_int_ts() { + std::vector ts = {1, 2}; + + double a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + EXPECT_FLOAT_EQ(output[0][0], 0.6649966577); + EXPECT_FLOAT_EQ(output[1][0], 0.09408000537); + } + + void test_tol_t0() { + stan::math::nested_rev_autodiff nested; + + var t0 = 0.0; + + double a = 1.5; + + std::vector> output + = apply_solver_t0_tol(t0, 1e-10, 1e-10, 1e6, a); + + output[0][0].grad(); + + EXPECT_FLOAT_EQ(output[0][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(t0.adj(), -1.0); + + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_FLOAT_EQ(output[1][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(t0.adj(), -1.0); + } + + void test_tol_ts() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {0.45, 1.1}; + + double a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + output[0][0].grad(); + + EXPECT_FLOAT_EQ(output[0][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(ts[0].adj(), 0.78070695113); + + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_FLOAT_EQ(output[1][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(ts[1].adj(), -0.0791208888); + } + + void test_tol_ts_repeat() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {0.45, 0.45, 1.1, 1.1}; + + double a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + EXPECT_EQ(output.size(), ts.size()); + + output[0][0].grad(); + + EXPECT_FLOAT_EQ(output[0][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(ts[0].adj(), 0.78070695113); + + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_FLOAT_EQ(output[1][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(ts[1].adj(), 0.78070695113); + + nested.set_zero_all_adjoints(); + + output[2][0].grad(); + + EXPECT_FLOAT_EQ(output[2][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(ts[2].adj(), -0.0791208888); + + nested.set_zero_all_adjoints(); + + output[3][0].grad(); + + EXPECT_FLOAT_EQ(output[3][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(ts[3].adj(), -0.0791208888); + } + + void test_tol_scalar_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a = 1.5; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a.adj(), -0.50107310888); + } + + void test_tol_scalar_arg_multi_time() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {0.45, 1.1}; + + var a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + output[0](0).grad(); + + EXPECT_FLOAT_EQ(output[0](0).val(), 0.4165982112); + EXPECT_FLOAT_EQ(a.adj(), -0.04352005542); + + nested.set_zero_all_adjoints(); + + output[1](0).grad(); + + EXPECT_FLOAT_EQ(output[1](0).val(), 0.66457668563); + EXPECT_FLOAT_EQ(a.adj(), -0.50107310888); + } + + void test_tol_std_vector_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + std::vector a = {1.5}; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a[0].adj(), -0.50107310888); + } + + void test_tol_vector_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + Eigen::Matrix a(1); + a << 1.5; + + var output = apply_solver_t0_tol(t0, 1e-8, 1e-10, 1e6, a)[1][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a(0).adj(), -0.50107310888); + } + + void test_tol_row_vector_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + Eigen::Matrix a(1); + a << 1.5; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a(0).adj(), -0.50107310888); + } + + void test_tol_matrix_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<1, T> sol; + + std::vector ts = {1.1}; + + Eigen::Matrix a(1, 1); + a << 1.5; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a(0, 0).adj(), -0.50107310888); + } + + void test_tol_scalar_std_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<1, T> sol; + + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + var output = sol(stan::math::ode_closure_adapter(), y0, t0, ts, 1e-8, 1e-10, 1e6, + nullptr, + stan::math::from_lambda([](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, a0), a1)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a0.adj(), -0.50107310888); + EXPECT_FLOAT_EQ(a1[0].adj(), -0.50107310888); + } + + void test_tol_std_vector_std_vector_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + std::vector a1(1, a0); + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0][0].adj(), -0.50107310888); + } + + void test_tol_std_vector_vector_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); + } + + void test_tol_std_vector_row_vector_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); + } + + void test_tol_std_vector_matrix_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + Eigen::Matrix a1(1, 1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); + } +}; + +#endif diff --git a/test/unit/math/rev/functor/test_fixture_ode_cos_scalar.hpp b/test/unit/math/rev/functor/test_fixture_ode_cos_scalar.hpp index dce22697732..d0241fae6b9 100644 --- a/test/unit/math/rev/functor/test_fixture_ode_cos_scalar.hpp +++ b/test/unit/math/rev/functor/test_fixture_ode_cos_scalar.hpp @@ -1285,27 +1285,6 @@ struct cos_arg_test : public cos_arg_ode_base, EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); } - void test_closure() { - stan::math::nested_rev_autodiff nested; - std::tuple_element_t<0, T> sol; - var a0 = 0.75; - std::vector a1 = {0.75}; - - auto f = stan::math::from_lambda( - [](const auto& b, const auto& t, const auto& y, const auto& a, - std::ostream* msgs) { - return stan::test::Cos2Arg()(t, y, msgs, a, b); - }, - a1); - var output = sol(f, y0, t0, ts, nullptr, a0)[1][0]; - - output.grad(); - - EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); - EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); - EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); - } - void test_tol_arg_combos_test() { stan::math::nested_rev_autodiff nested; std::tuple_element_t<1, T> sol; From c0435118f4059702583e85cf28de439470c963e7 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 15 Jun 2021 16:54:53 +0000 Subject: [PATCH 10/25] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/prim/functor/closure_adapter.hpp | 4 +- stan/math/prim/functor/integrate_ode_rk45.hpp | 7 +- ...grate_ode_std_vector_interface_adapter.hpp | 19 ++- stan/math/rev/functor/integrate_ode_adams.hpp | 7 +- stan/math/rev/functor/integrate_ode_bdf.hpp | 7 +- .../rev/functor/test_fixture_ode_closure.hpp | 134 +++++++++++------- 6 files changed, 108 insertions(+), 70 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 2d73223ee7f..34e9b6ac053 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -59,7 +59,9 @@ struct one_arg_closure { size_t count_vars__() const { return count_vars(s_); } auto value_of__() const { return ValueOf__(f_, eval(value_of(s_))); } auto copy_of__() const { return CopyOf__(f_, s_); } - auto deep_copy_vars__() const { return CopyOf__(f_, eval(deep_copy_vars(s_))); } + auto deep_copy_vars__() const { + return CopyOf__(f_, eval(deep_copy_vars(s_))); + } void zero_adjoints__() { zero_adjoints(s_); } double* accumulate_adjoints__(double* dest) const { return accumulate_adjoints(dest, s_); diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 31b691d1cf3..22b18e70dfd 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -22,9 +22,10 @@ inline auto integrate_ode_rk45( std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_rk45_tol_impl("integrate_ode_rk45", ode_closure_adapter(), to_vector(y0), t0, - ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f_adapted, theta, x, x_int); + auto y = ode_rk45_tol_impl("integrate_ode_rk45", ode_closure_adapter(), + to_vector(y0), t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, f_adapted, + theta, x, x_int); std::vector>> y_converted; diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 377de796dce..30dc45adb8b 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -26,10 +26,10 @@ struct integrate_ode_std_vector_interface_adapter_impl; template struct integrate_ode_std_vector_interface_adapter_impl { using captured_scalar_t__ = double; - using ValueOf__ = integrate_ode_std_vector_interface_adapter_impl< - false, F, false>; - using CopyOf__ = integrate_ode_std_vector_interface_adapter_impl< - false, F, false>; + using ValueOf__ + = integrate_ode_std_vector_interface_adapter_impl; + using CopyOf__ + = integrate_ode_std_vector_interface_adapter_impl; const F& f_; explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) : f_(f) {} @@ -37,18 +37,15 @@ struct integrate_ode_std_vector_interface_adapter_impl { template auto operator()(std::ostream* msgs, const T0& t, const Eigen::Matrix& y, - const std::vector& theta, - const std::vector& x, + const std::vector& theta, const std::vector& x, const std::vector& x_int) const { return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs)); } - size_t count_vars__() const { } + size_t count_vars__() const {} auto value_of__() const { return ValueOf__(f_); } auto deep_copy_vars__() const { return CopyOf__(f_); } - void zero_adjoints__() { } - double* accumulate_adjoints__(double* dest) const { - return dest; - } + void zero_adjoints__() {} + double* accumulate_adjoints__(double* dest) const { return dest; } template Vari** save_varis__(Vari** dest) const { return dest; diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index ca48b2cbfe2..8d1c2b18c03 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -25,9 +25,10 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_adams_tol_impl("integrate_ode_adams", ode_closure_adapter(), to_vector(y0), - t0, ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f_adapted, theta, x, x_int); + auto y = ode_adams_tol_impl("integrate_ode_adams", ode_closure_adapter(), + to_vector(y0), t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, + f_adapted, theta, x, x_int); std::vector>> y_converted; diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index a7f113e201e..f25601aa6c0 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -25,9 +25,10 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_bdf_tol_impl("integrate_ode_bdf", ode_closure_adapter(), to_vector(y0), t0, - ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, f_adapted, theta, x, x_int); + auto y = ode_bdf_tol_impl("integrate_ode_bdf", ode_closure_adapter(), + to_vector(y0), t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, f_adapted, + theta, x, x_int); std::vector>> y_converted; diff --git a/test/unit/math/rev/functor/test_fixture_ode_closure.hpp b/test/unit/math/rev/functor/test_fixture_ode_closure.hpp index 8dac5742d28..9b2a70b040a 100644 --- a/test/unit/math/rev/functor/test_fixture_ode_closure.hpp +++ b/test/unit/math/rev/functor/test_fixture_ode_closure.hpp @@ -50,108 +50,140 @@ struct closure_test : public closure_ode_base, auto apply_solver() { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, - stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a_); - }, a)); + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); } template auto apply_solver(Eigen::Matrix& init, std::vector& va) { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), init, t0, ts, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a); - }, va)); + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + va)); } template auto apply_solver_ts(const std::vector& ts_) { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, nullptr, - stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a_); - }, a)); + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); } template auto apply_solver_ts(const std::vector& ts_, const a_type& arg) { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, nullptr, - stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a_); - }, arg)); + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + arg)); } template - auto apply_solver_ts_tol(const std::vector& ts_, double rtol, double atol, int max_num_steps, const a_type& a_) { + auto apply_solver_ts_tol(const std::vector& ts_, double rtol, double atol, + int max_num_steps, const a_type& a_) { std::tuple_element_t<1, T> sol; - return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, rtol, atol, max_num_steps, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a); - }, a_)); + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, rtol, atol, + max_num_steps, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); } template auto apply_solver_t0(const T0& t0_) { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, nullptr, - stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a_); - }, a)); + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); } template - auto apply_solver_t0_tol(const T0& t0_, double rtol, double atol, int max_num_steps, const a_type& a_) { + auto apply_solver_t0_tol(const T0& t0_, double rtol, double atol, + int max_num_steps, const a_type& a_) { std::tuple_element_t<1, T> sol; - return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, rtol, atol, max_num_steps, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a); - }, a_)); + return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, rtol, atol, + max_num_steps, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); } auto apply_solver_tol() { std::tuple_element_t<1, T> sol; - return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, max_num_step, - nullptr, - stan::math::from_lambda([](auto& a_, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a_); - }, a)); + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, + max_num_step, nullptr, + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); } template auto apply_solver_arg(a_type const& a_) { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a); - }, a_)); + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); } template auto apply_solver_arg_tol(a_type const& a_) { std::tuple_element_t<1, T> sol; - return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, max_num_step, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, std::ostream* msgs) { - return stan::test::CosArg1()(t, y, msgs, a); - }, a_)); + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, + max_num_step, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); } template auto apply_solver_arg(a_type const& a_, b_type const& b_) { std::tuple_element_t<0, T> sol; return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { - return stan::test::Cos2Arg()(t, y, msgs, a, b); - }, a_), b_); + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a_), + b_); } template auto apply_solver_arg_tol(a_type const& a_, b_type const& b_) { std::tuple_element_t<1, T> sol; - return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, max_num_step, nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { - return stan::test::Cos2Arg()(t, y, msgs, a, b); - }, a_), b_); + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, + max_num_step, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a_), + b_); } void test_y0_error() { @@ -1074,11 +1106,15 @@ struct closure_test : public closure_ode_base, var a0 = 0.75; std::vector a1 = {0.75}; - var output = sol(stan::math::ode_closure_adapter(), y0, t0, ts, 1e-8, 1e-10, 1e6, - nullptr, - stan::math::from_lambda([](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { - return stan::test::Cos2Arg()(t, y, msgs, a, b); - }, a0), a1)[0][0]; + var output + = sol(stan::math::ode_closure_adapter(), y0, t0, ts, 1e-8, 1e-10, 1e6, + nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0), + a1)[0][0]; output.grad(); From e72e6f478598de94b22614b80e5263aece53d8fd Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 15 Jun 2021 20:31:45 +0300 Subject: [PATCH 11/25] fix cpplint --- stan/math/prim/functor/integrate_1d_adapter.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stan/math/prim/functor/integrate_1d_adapter.hpp b/stan/math/prim/functor/integrate_1d_adapter.hpp index 60a7ebb3fc4..1bf065f809f 100644 --- a/stan/math/prim/functor/integrate_1d_adapter.hpp +++ b/stan/math/prim/functor/integrate_1d_adapter.hpp @@ -29,7 +29,7 @@ struct integrate_1d_adapter { * Call a closure object from integrate_1d */ struct integrate_1d_closure_adapter { - explicit integrate_1d_closure_adapter() {} + integrate_1d_closure_adapter() {} template auto operator()(const T_a& x, const T_b& xc, std::ostream* msgs, const F& f, From e12006486a0e7e3f135beca14f1a76e1c882a91f Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 15 Jun 2021 21:03:12 +0300 Subject: [PATCH 12/25] fix test --- .../functor/integrate_ode_std_vector_interface_adapter_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp index 1d7f484a356..2e0f25b31ad 100644 --- a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -21,7 +21,7 @@ TEST(StanMath, check_values) { Eigen::VectorXd out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::VectorXd out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); EXPECT_MATRIX_FLOAT_EQ(out1, out2); } From 1245fa6f11960a4aa0ca66c599e8ec3567d03dbb Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 15 Jun 2021 22:24:50 +0300 Subject: [PATCH 13/25] fix promote_scalar_type --- stan/math/prim/functor/closure_adapter.hpp | 4 ++-- stan/math/prim/meta/promote_scalar_type.hpp | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 34e9b6ac053..6d0aa2d8602 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -166,8 +166,8 @@ struct empty_closure_lp { */ struct ode_closure_adapter { template - auto operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const F& f, Args... args) const { + auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, + Args... args) const { return f(msgs, t, y, args...); } }; diff --git a/stan/math/prim/meta/promote_scalar_type.hpp b/stan/math/prim/meta/promote_scalar_type.hpp index 4c03569903a..2f2a11382d5 100644 --- a/stan/math/prim/meta/promote_scalar_type.hpp +++ b/stan/math/prim/meta/promote_scalar_type.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace stan { @@ -93,6 +94,23 @@ struct promote_scalar_type> { S::RowsAtCompileTime, S::ColsAtCompileTime>>::type; }; +/** + * Template metaprogram to calculate a type for a closure whose + * underlying scalar is converted from the second template + * parameter type to the first. + * + * @tparam T result scalar type. + * @tparam S input closure type + */ +template +struct promote_scalar_type> { + /** + * The promoted type. + */ + using type = typename std::conditional::value, F, + typename F::ValueOf__>::type; +}; + template using promote_scalar_t = typename promote_scalar_type, std::decay_t>::type; From 9c2581759423fd61278a1f95467f3c45f7873a6d Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 15 Jun 2021 23:19:22 +0300 Subject: [PATCH 14/25] fix --- .../prim/functor/integrate_ode_std_vector_interface_adapter.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 30dc45adb8b..04e20757020 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -41,7 +41,7 @@ struct integrate_ode_std_vector_interface_adapter_impl { const std::vector& x_int) const { return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs)); } - size_t count_vars__() const {} + size_t count_vars__() const { return 0; } auto value_of__() const { return ValueOf__(f_); } auto deep_copy_vars__() const { return CopyOf__(f_); } void zero_adjoints__() {} From 6990768136daeb463f89a077e34ddc80e770480b Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Mon, 19 Jul 2021 14:42:45 +0300 Subject: [PATCH 15/25] generalize base_closure --- stan/math/prim/functor/closure_adapter.hpp | 90 +++++++++------------- 1 file changed, 35 insertions(+), 55 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 6d0aa2d8602..5ad2cbeab86 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace stan { @@ -11,64 +12,51 @@ namespace math { namespace internal { /** - * A closure that wraps a C++ lambda. + * A closure that wraps a C++ lambda and captures values. */ -template -struct empty_closure { - using captured_scalar_t__ = double; - using ValueOf__ = empty_closure; - using CopyOf__ = empty_closure; +template +struct base_closure { + using captured_scalar_t__ = return_type_t; + using ValueOf__ + = base_closure())))...>; + using CopyOf__ = base_closure; F f_; + std::tuple...> captures_; - explicit empty_closure(const F& f) : f_(f) {} + explicit base_closure(const F& f, const Ts&... args) + : f_(f), captures_(args...) {} template auto operator()(std::ostream* msgs, Args... args) const { - return f_(args..., msgs); + return apply( + [this, msgs, args...](auto... s) { return f_(s..., args..., msgs); }, + captures_); } - size_t count_vars__() const { return 0; } - auto value_of__() const { return ValueOf__(f_); } - auto copy_of__() const { return CopyOf__(f_); } - auto deep_copy_vars__() const { return CopyOf__(f_); } - void zero_adjoints__() const {} - double* accumulate_adjoints__(double* dest) const { return dest; } - template - Vari** save_varis(Vari** dest) const { - return dest; + size_t count_vars__() const { + return apply([this](auto... s) { return count_vars(s...); }, captures_); } -}; - -/** - * A closure that holds one autodiffable capture. - */ -template -struct one_arg_closure { - using captured_scalar_t__ = return_type_t; - using ValueOf__ - = one_arg_closure())))>; - using CopyOf__ = one_arg_closure; - F f_; - capture_type_t s_; - - explicit one_arg_closure(const F& f, const T& s) : f_(f), s_(s) {} - - template - auto operator()(std::ostream* msgs, Args... args) const { - return f_(s_, args..., msgs); + auto value_of__() const { + return apply( + [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, + captures_); } - size_t count_vars__() const { return count_vars(s_); } - auto value_of__() const { return ValueOf__(f_, eval(value_of(s_))); } - auto copy_of__() const { return CopyOf__(f_, s_); } auto deep_copy_vars__() const { - return CopyOf__(f_, eval(deep_copy_vars(s_))); + return apply( + [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, + captures_); + } + void zero_adjoints__() { + apply([](auto... s) { zero_adjoints(s...); }, captures_); } - void zero_adjoints__() { zero_adjoints(s_); } double* accumulate_adjoints__(double* dest) const { - return accumulate_adjoints(dest, s_); + return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, + captures_); } template Vari** save_varis__(Vari** dest) const { - return save_varis(dest, s_); + return apply([dest](auto... s) { return save_varis(dest, s...); }, + captures_); } }; @@ -173,19 +161,11 @@ struct ode_closure_adapter { }; /** - * Create a closure object from a callable. - */ -template -auto from_lambda(const F& f) { - return internal::empty_closure(f); -} - -/** - * Create a closure that captures a single argument. + * Create a closure from a C++ lambda and captures. */ -template -auto from_lambda(const F& f, const T& a) { - return internal::one_arg_closure(f, a); +template +auto from_lambda(const F& f, const Ts&... a) { + return internal::base_closure(f, a...); } /** From 2e3180a73f014abb729aac9349b6be416ece679d Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 27 Jul 2021 19:54:54 +0300 Subject: [PATCH 16/25] non-empty suffix closures --- stan/math/prim/functor/closure_adapter.hpp | 245 ++++++++++++--------- 1 file changed, 140 insertions(+), 105 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 5ad2cbeab86..58e65755118 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -18,8 +18,7 @@ template struct base_closure { using captured_scalar_t__ = return_type_t; using ValueOf__ - = base_closure())))...>; + = base_closure())))...>; using CopyOf__ = base_closure; F f_; std::tuple...> captures_; @@ -28,10 +27,10 @@ struct base_closure { : f_(f), captures_(args...) {} template - auto operator()(std::ostream* msgs, Args... args) const { - return apply( - [this, msgs, args...](auto... s) { return f_(s..., args..., msgs); }, - captures_); + auto operator()(std::ostream* msgs, const Args&... args) const { + return apply([this, msgs, &args...]( + const auto&... s) { return f_(s..., args..., msgs); }, + captures_); } size_t count_vars__() const { return apply([this](auto... s) { return count_vars(s...); }, captures_); @@ -63,86 +62,159 @@ struct base_closure { /** * A closure that takes rng argument. */ -template -struct empty_closure_rng { +template +struct closure_rng { using captured_scalar_t__ = double; - using ValueOf__ = empty_closure_rng; - using CopyOf__ = empty_closure_rng; + using ValueOf__ = closure_rng; + using CopyOf__ = closure_rng; F f_; + std::tuple...> captures_; - explicit empty_closure_rng(const F& f) : f_(f) {} + explicit closure_rng(const F& f, const Ts&... args) + : f_(f), captures_(args...) {} template - auto operator()(const Rng& rng, std::ostream* msgs, Args... args) const { - return f_(args..., rng, msgs); - } - size_t count_vars__() const { return 0; } - auto value_of__() const { return ValueOf__(f_); } - auto copy_of__() const { return CopyOf__(f_); } - auto deep_copy_vars__() const { return CopyOf__(f_); } - void zero_adjoints__() const {} - double* accumulate_adjoints__(double* dest) const { return dest; } + auto operator()(Rng& rng, std::ostream* msgs, const Args&... args) const { + return apply([this, &rng, msgs, &args...]( + const auto&... s) { return f_(s..., args..., rng, msgs); }, + captures_); + } + + size_t count_vars__() const { + return apply([this](auto... s) { return count_vars(s...); }, captures_); + } + auto value_of__() const { + return apply( + [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, + captures_); + } + auto deep_copy_vars__() const { + return apply( + [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, + captures_); + } + void zero_adjoints__() { + apply([](auto... s) { zero_adjoints(s...); }, captures_); + } + double* accumulate_adjoints__(double* dest) const { + return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, + captures_); + } template - Vari** save_varis(Vari** dest) const { - return dest; + Vari** save_varis__(Vari** dest) const { + return apply([dest](auto... s) { return save_varis(dest, s...); }, + captures_); } }; /** * A closure that can be called with `propto` template argument. */ -template -struct empty_closure_lpdf { - using captured_scalar_t__ = double; - using ValueOf__ = empty_closure_lpdf; - using CopyOf__ = empty_closure_lpdf; +template +struct closure_lpdf { + using captured_scalar_t__ = return_type_t; + using ValueOf__ = closure_lpdf; + using CopyOf__ = closure_lpdf; F f_; + std::tuple...> captures_; - explicit empty_closure_lpdf(const F& f) : f_(f) {} + explicit closure_lpdf(const F& f, const Ts&... args) + : f_(f), captures_(args...) {} + + template + auto with_propto() { + return apply( + [this](const auto&... args) { + return closure_lpdf < Propto && propto, true, F, + Ts... > (f_, args...); + }, + captures_); + } template - auto operator()(std::ostream* msgs, Args... args) const { - return f_.template operator()(args..., msgs); - } - size_t count_vars__() const { return 0; } - auto value_of__() const { return ValueOf__(f_); } - auto copy_of__() const { return CopyOf__(f_); } - auto deep_copy_vars__() const { return CopyOf__(f_); } - void zero_adjoints__() const {} - double* accumulate_adjoints__(double* dest) const { return dest; } + auto operator()(std::ostream* msgs, const Args&... args) const { + return apply( + [this, msgs, &args...](const auto&... s) { + return f_.template operator()(s..., args..., msgs); + }, + captures_); + } + size_t count_vars__() const { + return apply([this](auto... s) { return count_vars(s...); }, captures_); + } + auto value_of__() const { + return apply( + [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, + captures_); + } + auto deep_copy_vars__() const { + return apply( + [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, + captures_); + } + void zero_adjoints__() { + apply([](auto... s) { zero_adjoints(s...); }, captures_); + } + double* accumulate_adjoints__(double* dest) const { + return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, + captures_); + } template - Vari** save_varis(Vari** dest) const { - return dest; + Vari** save_varis__(Vari** dest) const { + return apply([dest](auto... s) { return save_varis(dest, s...); }, + captures_); } }; /** * A closure that accesses logprob accumulator. */ -template -struct empty_closure_lp { - using captured_scalar_t__ = double; - using ValueOf__ = empty_closure_lp; - using CopyOf__ = empty_closure_lp; - static const size_t vars_count__ = 0; +template +struct closure_lp { + using captured_scalar_t__ = return_type_t; + using ValueOf__ = closure_lp; + using CopyOf__ = closure_lp; F f_; + std::tuple...> captures_; - explicit empty_closure_lp(const F& f) : f_(f) {} + explicit closure_lp(const F& f, const Ts&... args) + : f_(f), captures_(args...) {} - template - auto operator()(T_lp_accum& lp, T_lp& lp_accum, std::ostream* msgs, - Args... args) const { - return f_(args..., lp, lp_accum, msgs); - } - size_t count_vars__() const { return 0; } - auto value_of__() const { return ValueOf__(f_); } - auto copy_of__() const { return CopyOf__(f_); } - auto deep_copy_vars__() const { return CopyOf__(f_); } - void zero_adjoints__() const {} - double* accumulate_adjoints__(double* dest) const { return dest; } + template + auto operator()(T_lp& lp, T_lp_accum& lp_accum, std::ostream* msgs, + const Args&... args) const { + return apply( + [this, &lp, &lp_accum, msgs, &args...](const auto&... s) { + return f_.template operator()(s..., args..., lp, lp_accum, + msgs); + }, + captures_); + } + size_t count_vars__() const { + return apply([this](auto... s) { return count_vars(s...); }, captures_); + } + auto value_of__() const { + return apply( + [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, + captures_); + } + auto deep_copy_vars__() const { + return apply( + [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, + captures_); + } + void zero_adjoints__() { + apply([](auto... s) { zero_adjoints(s...); }, captures_); + } + double* accumulate_adjoints__(double* dest) const { + return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, + captures_); + } template - Vari** save_varis(Vari** dest) const { - return dest; + Vari** save_varis__(Vari** dest) const { + return apply([dest](auto... s) { return save_varis(dest, s...); }, + captures_); } }; @@ -171,64 +243,27 @@ auto from_lambda(const F& f, const Ts&... a) { /** * Create a closure from an rng functor. */ -template -auto rng_from_lambda(const F& f) { - return internal::empty_closure_rng(f); +template +auto rng_from_lambda(const F& f, const Ts&... a) { + return internal::closure_rng(f, a...); } /** * Create a closure from an lpdf functor. */ -template -auto lpdf_from_lambda(const F& f) { - return internal::empty_closure_lpdf(f); +template +auto lpdf_from_lambda(const F& f, const Ts&... a) { + return internal::closure_lpdf(f, a...); } /** * Create a closure from a functor that needs access to logprob accumulator. */ -template -auto lp_from_lambda(const F& f) { - return internal::empty_closure_lp(f); +template +auto lp_from_lambda(const F& f, const Ts&... args) { + return internal::closure_lp(f, args...); } -/** - * A wrapper that sets propto template argument when calling the inner closure. - */ -template -struct lpdf_wrapper { - using captured_scalar_t__ = return_type_t; - using ValueOf__ - = lpdf_wrapper().value_of__()), false>; - using CopyOf__ - = lpdf_wrapper().copy_of__()), false>; - capture_type_t f_; - - explicit lpdf_wrapper(const F& f) : f_(f) {} - - template - auto with_propto() { - return lpdf_wrapper < Propto && propto, F, true > (f_); - } - - template - auto operator()(Args... args) const { - return f_.template operator() < Propto && propto > (args...); - } - size_t count_vars__() const { return count_vars(f_); } - auto value_of__() const { return ValueOf__(value_of(f_)); } - auto deep_copy_vars__() const { return CopyOf__(deep_copy_vars(f_)); } - auto copy_of__() const { return CopyOf__(f_.copy_of__()); } - void zero_adjoints__() { zero_adjoints(f_); } - double* accumulate_adjoints__(double* dest) const { - return accumulate_adjoints(dest, f_); - } - template - Vari** save_varis__(Vari** dest) const { - return save_varis(dest, f_); - } -}; - /** * Higher-order functor that invokes a closure inside a reduce_sum call. */ From 610af2d45ef1134d9ce126aae9d417113b24c74c Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 27 Jul 2021 22:54:12 +0300 Subject: [PATCH 17/25] cleanup --- stan/math/prim/fun/eval.hpp | 18 ++- stan/math/prim/fun/value_of.hpp | 7 +- stan/math/prim/functor/closure_adapter.hpp | 118 +++--------------- stan/math/prim/functor/integrate_ode_rk45.hpp | 41 +++++- ...grate_ode_std_vector_interface_adapter.hpp | 68 +--------- stan/math/rev/core/accumulate_adjoints.hpp | 6 +- stan/math/rev/core/count_vars.hpp | 6 +- stan/math/rev/core/deep_copy_vars.hpp | 15 ++- stan/math/rev/core/save_varis.hpp | 4 +- stan/math/rev/core/zero_adjoints.hpp | 3 +- stan/math/rev/functor/integrate_ode_adams.hpp | 45 ++++++- stan/math/rev/functor/integrate_ode_bdf.hpp | 47 ++++++- ..._ode_std_vector_interface_adapter_test.cpp | 2 +- ..._ode_std_vector_interface_adapter_test.cpp | 6 +- 14 files changed, 189 insertions(+), 197 deletions(-) diff --git a/stan/math/prim/fun/eval.hpp b/stan/math/prim/fun/eval.hpp index cc7f6718daf..836381ce8b0 100644 --- a/stan/math/prim/fun/eval.hpp +++ b/stan/math/prim/fun/eval.hpp @@ -7,6 +7,18 @@ namespace stan { namespace math { +/** + * Inputs which have a closure type are forwarded unmodified + * + * @tparam T Input type + * @param[in] arg Input argument + * @return Forwarded input argument + **/ +template * = nullptr> +inline T eval(T&& arg) { + return std::forward(arg); +} + /** * Inputs which have a plain_type equal to the own time are forwarded * unmodified (for Eigen expressions these types are different) @@ -16,7 +28,8 @@ namespace math { * @return Forwarded input argument **/ template , plain_type_t>* = nullptr> + require_same_t, plain_type_t>* = nullptr, + require_not_stan_closure_t* = nullptr> inline T eval(T&& arg) { return std::forward(arg); } @@ -30,7 +43,8 @@ inline T eval(T&& arg) { * @return Eval'd argument **/ template , plain_type_t>* = nullptr> + require_not_same_t, plain_type_t>* = nullptr, + require_not_stan_closure_t* = nullptr> inline decltype(auto) eval(const T& arg) { return arg.eval(); } diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 73d4ac69e9a..c114c17a4ce 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -79,7 +80,11 @@ inline auto value_of(EigMat&& M) { template * = nullptr, require_not_st_arithmetic* = nullptr> inline auto value_of(const F& f) { - return f.value_of__(); + return apply( + [&f](const auto&... s) { + return typename F::ValueOf__(f.f_, eval(value_of(s))...); + }, + f.captures_); } } // namespace math diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 58e65755118..5b8859a9b7c 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -32,31 +32,6 @@ struct base_closure { const auto&... s) { return f_(s..., args..., msgs); }, captures_); } - size_t count_vars__() const { - return apply([this](auto... s) { return count_vars(s...); }, captures_); - } - auto value_of__() const { - return apply( - [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, - captures_); - } - auto deep_copy_vars__() const { - return apply( - [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, - captures_); - } - void zero_adjoints__() { - apply([](auto... s) { zero_adjoints(s...); }, captures_); - } - double* accumulate_adjoints__(double* dest) const { - return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, - captures_); - } - template - Vari** save_varis__(Vari** dest) const { - return apply([dest](auto... s) { return save_varis(dest, s...); }, - captures_); - } }; /** @@ -79,32 +54,6 @@ struct closure_rng { const auto&... s) { return f_(s..., args..., rng, msgs); }, captures_); } - - size_t count_vars__() const { - return apply([this](auto... s) { return count_vars(s...); }, captures_); - } - auto value_of__() const { - return apply( - [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, - captures_); - } - auto deep_copy_vars__() const { - return apply( - [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, - captures_); - } - void zero_adjoints__() { - apply([](auto... s) { zero_adjoints(s...); }, captures_); - } - double* accumulate_adjoints__(double* dest) const { - return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, - captures_); - } - template - Vari** save_varis__(Vari** dest) const { - return apply([dest](auto... s) { return save_varis(dest, s...); }, - captures_); - } }; /** @@ -139,31 +88,6 @@ struct closure_lpdf { }, captures_); } - size_t count_vars__() const { - return apply([this](auto... s) { return count_vars(s...); }, captures_); - } - auto value_of__() const { - return apply( - [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, - captures_); - } - auto deep_copy_vars__() const { - return apply( - [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, - captures_); - } - void zero_adjoints__() { - apply([](auto... s) { zero_adjoints(s...); }, captures_); - } - double* accumulate_adjoints__(double* dest) const { - return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, - captures_); - } - template - Vari** save_varis__(Vari** dest) const { - return apply([dest](auto... s) { return save_varis(dest, s...); }, - captures_); - } }; /** @@ -191,31 +115,6 @@ struct closure_lp { }, captures_); } - size_t count_vars__() const { - return apply([this](auto... s) { return count_vars(s...); }, captures_); - } - auto value_of__() const { - return apply( - [this](auto... s) { return ValueOf__(f_, eval(value_of(s))...); }, - captures_); - } - auto deep_copy_vars__() const { - return apply( - [this](auto... s) { return CopyOf__(f_, eval(deep_copy_vars(s))...); }, - captures_); - } - void zero_adjoints__() { - apply([](auto... s) { zero_adjoints(s...); }, captures_); - } - double* accumulate_adjoints__(double* dest) const { - return apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, - captures_); - } - template - Vari** save_varis__(Vari** dest) const { - return apply([dest](auto... s) { return save_varis(dest, s...); }, - captures_); - } }; } // namespace internal @@ -232,6 +131,23 @@ struct ode_closure_adapter { } }; +struct integrate_ode_closure_adapter { + template > + auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, + Args... args) const { + return to_vector(f(msgs, t, to_array_1d(y), args...)); + } + + template > + auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, + const std::vector& theta, const std::vector& x, + const std::vector& x_int) const { + return to_vector(f(t, to_array_1d(y), theta, x, x_int, msgs)); + } +}; + /** * Create a closure from a C++ lambda and captures. */ diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 22b18e70dfd..c3ad8aeadc9 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_RK45_HPP #include +#include #include #include #include @@ -10,6 +11,38 @@ namespace stan { namespace math { +namespace internal { + +template * = nullptr> +inline auto integrate_ode_rk45_impl( + const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs, double relative_tolerance, double absolute_tolerance, + int max_num_steps) { + internal::integrate_ode_std_vector_interface_adapter f_adapted(f); + return ode_rk45_tol_impl("integrate_ode_rk45", f_adapted, to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, theta, x, x_int); +} + +template * = nullptr> +inline auto integrate_ode_rk45_impl( + const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs, double relative_tolerance, double absolute_tolerance, + int max_num_steps) { + return ode_rk45_tol_impl("integrate_ode_rk45", + integrate_ode_closure_adapter(), to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, theta, x, x_int); +} + +} // namespace internal + /** * @deprecated use ode_rk45 */ @@ -21,11 +54,9 @@ inline auto integrate_ode_rk45( const std::vector& x, const std::vector& x_int, std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { - internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_rk45_tol_impl("integrate_ode_rk45", ode_closure_adapter(), - to_vector(y0), t0, ts, relative_tolerance, - absolute_tolerance, max_num_steps, msgs, f_adapted, - theta, x, x_int); + auto y = internal::integrate_ode_rk45_impl(f, y0, t0, ts, theta, x, x_int, + msgs, relative_tolerance, + absolute_tolerance, max_num_steps); std::vector>> y_converted; diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 04e20757020..b61d9a9bfa6 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -20,76 +20,20 @@ namespace internal { * state as an Eigen::Matrix. The adapter converts to and from these forms * so that the old ODE interfaces can work. */ -template -struct integrate_ode_std_vector_interface_adapter_impl; - -template -struct integrate_ode_std_vector_interface_adapter_impl { - using captured_scalar_t__ = double; - using ValueOf__ - = integrate_ode_std_vector_interface_adapter_impl; - using CopyOf__ - = integrate_ode_std_vector_interface_adapter_impl; +template +struct integrate_ode_std_vector_interface_adapter { const F& f_; - explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) - : f_(f) {} + explicit integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {} template - auto operator()(std::ostream* msgs, const T0& t, - const Eigen::Matrix& y, - const std::vector& theta, const std::vector& x, + auto operator()(const T0& t, const Eigen::Matrix& y, + std::ostream* msgs, const std::vector& theta, + const std::vector& x, const std::vector& x_int) const { return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs)); } - size_t count_vars__() const { return 0; } - auto value_of__() const { return ValueOf__(f_); } - auto deep_copy_vars__() const { return CopyOf__(f_); } - void zero_adjoints__() {} - double* accumulate_adjoints__(double* dest) const { return dest; } - template - Vari** save_varis__(Vari** dest) const { - return dest; - } }; -template -struct integrate_ode_std_vector_interface_adapter_impl { - using captured_scalar_t__ = typename F::captured_scalar_t__; - using ValueOf__ = integrate_ode_std_vector_interface_adapter_impl< - true, typename F::ValueOf__, false>; - using CopyOf__ = integrate_ode_std_vector_interface_adapter_impl< - true, typename F::CopyOf__, false>; - capture_type_t f_; - - explicit integrate_ode_std_vector_interface_adapter_impl(const F& f) - : f_(f) {} - - template - auto operator()(std::ostream* msgs, const T0& t, - const Eigen::Matrix& y, - const std::vector& theta, const std::vector& x, - const std::vector& x_int) const { - return to_vector(f_(msgs, t, to_array_1d(y), theta, x, x_int)); - } - - size_t count_vars__() const { return f_.count_vars__(); } - auto value_of__() const { return ValueOf__(f_.value_of__()); } - auto deep_copy_vars__() const { return CopyOf__(f_.deep_copy_vars__()); } - void zero_adjoints__() { f_.zero_adjoints__(); } - double* accumulate_adjoints__(double* dest) const { - return f_.accumulate_adjoints__(dest); - } - template - Vari** save_varis__(Vari** dest) const { - return f_.save_varis__(dest); - } -}; - -template -using integrate_ode_std_vector_interface_adapter - = integrate_ode_std_vector_interface_adapter_impl::value, - F, true>; - } // namespace internal } // namespace math diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index 1a95dba46af..f5f9f5c5ba4 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -142,8 +142,10 @@ inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { template *, require_not_st_arithmetic*, typename... Pargs> inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args) { - return accumulate_adjoints(f.accumulate_adjoints__(dest), - std::forward(args)...); + return accumulate_adjoints( + apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, + f.captures_), + std::forward(args)...); } /** diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index e6463de54b5..9a3fe214da9 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -130,8 +130,10 @@ inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { template *, require_not_st_arithmetic*, typename... Pargs> inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args) { - return count_vars_impl(count + f.count_vars__(), - std::forward(args)...); + return count_vars_impl( + apply([count](auto... s) { return count_vars_impl(count, s...); }, + f.captures_), + std::forward(args)...); } /** diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 047bb62ef73..32bc6afece7 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_DEEP_COPY_VARS_HPP #include +#include #include #include @@ -19,7 +20,8 @@ namespace math { * @param arg For lvalue references this will be passed by reference. * Otherwise it will be moved. */ -template >> +template >, + typename = require_not_stan_closure_t> inline Arith deep_copy_vars(Arith&& arg) { return std::forward(arg); } @@ -88,10 +90,13 @@ inline auto deep_copy_vars(EigT&& arg) { * @param f A closure containing vars * @return A new closure containing vars */ -template * = nullptr, - require_not_arithmetic_t>* = nullptr> -inline auto deep_copy_vars(F&& f) { - return f.deep_copy_vars__(); +template * = nullptr> +inline auto deep_copy_vars(const F& f) { + return apply( + [&f](const auto&... s) { + return typename F::CopyOf__(f.f_, eval(deep_copy_vars(s))...); + }, + f.captures_); } } // namespace math diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index 6e19d54fa53..2e017a2552c 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -138,7 +138,9 @@ inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args) { template *, require_not_st_arithmetic*, typename... Pargs> inline vari** save_varis(vari** dest, F& f, Pargs&&... args) { - return save_varis(f.save_varis__(dest), std::forward(args)...); + return save_varis( + apply([dest](auto... s) { return save_varis(dest, s...); }, f.captures_), + std::forward(args)...); } /** diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index 0dc706599db..4c1285b3dc6 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -67,7 +67,8 @@ inline void zero_adjoints(EigMat& x) { template * = nullptr, require_not_st_arithmetic* = nullptr> inline void zero_adjoints(F& f, Pargs&... args) { - f.zero_adjoints__(); + apply([](auto... s) { zero_adjoints(s...); }, f.captures_); + ; zero_adjoints(args...); } diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 8d1c2b18c03..44d7b896af3 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -10,6 +11,42 @@ namespace stan { namespace math { +namespace internal { + +template * = nullptr> +auto integrate_ode_adams_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + internal::integrate_ode_std_vector_interface_adapter f_adapted(f); + return ode_adams_tol_impl("integrate_ode_adams", f_adapted, to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, theta, x, x_int); +} + +template * = nullptr> +auto integrate_ode_adams_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + return ode_adams_tol_impl("integrate_ode_adams", + integrate_ode_closure_adapter(), to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, theta, x, x_int); +} + +} // namespace internal + /** * @deprecated use ode_adams */ @@ -24,11 +61,9 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, double relative_tolerance = 1e-10, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) - internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_adams_tol_impl("integrate_ode_adams", ode_closure_adapter(), - to_vector(y0), t0, ts, relative_tolerance, - absolute_tolerance, max_num_steps, msgs, - f_adapted, theta, x, x_int); + auto y = internal::integrate_ode_adams_impl( + f, y0, t0, ts, theta, x, x_int, msgs, relative_tolerance, + absolute_tolerance, max_num_steps); std::vector>> y_converted; diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index f25601aa6c0..1f7269a1897 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_FUNCTOR_INTEGRATE_ODE_BDF_HPP #include +#include #include #include #include @@ -10,11 +11,47 @@ namespace stan { namespace math { +namespace internal { + +template * = nullptr> +auto integrate_ode_bdf_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + internal::integrate_ode_std_vector_interface_adapter f_adapted(f); + return ode_bdf_tol_impl("integrate_ode_bdf", f_adapted, to_vector(y0), t0, ts, + relative_tolerance, absolute_tolerance, max_num_steps, + msgs, theta, x, x_int); +} + +template * = nullptr> +auto integrate_ode_bdf_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + return ode_bdf_tol_impl("integrate_ode_bdf", integrate_ode_closure_adapter(), + to_vector(y0), t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, f, theta, x, + x_int); +} + +} // namespace internal + /** * @deprecated use ode_bdf */ template + typename T_ts, typename = require_not_stan_closure_t> std::vector>> integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, @@ -24,11 +61,9 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, double relative_tolerance = 1e-10, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) - internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_bdf_tol_impl("integrate_ode_bdf", ode_closure_adapter(), - to_vector(y0), t0, ts, relative_tolerance, - absolute_tolerance, max_num_steps, msgs, f_adapted, - theta, x, x_int); + auto y = internal::integrate_ode_bdf_impl(f, y0, t0, ts, theta, x, x_int, + msgs, relative_tolerance, + absolute_tolerance, max_num_steps); std::vector>> y_converted; diff --git a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp index 2e0f25b31ad..1d7f484a356 100644 --- a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -21,7 +21,7 @@ TEST(StanMath, check_values) { Eigen::VectorXd out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::VectorXd out2 - = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); + = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); EXPECT_MATRIX_FLOAT_EQ(out1, out2); } diff --git a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp index effa39ac332..f8611f7f57a 100644 --- a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -22,7 +22,7 @@ TEST(StanMathRev, vd) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); + = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(theta.size()); @@ -58,7 +58,7 @@ TEST(StanMathRev, dv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); + = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(y.size()); @@ -94,7 +94,7 @@ TEST(StanMathRev, vv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); + = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs_theta_1(theta.size()); From 880e270ef34875525fd1bf0303b61301ed87f460 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 27 Jul 2021 23:21:56 +0300 Subject: [PATCH 18/25] remove empty statement --- stan/math/rev/core/zero_adjoints.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index 4c1285b3dc6..c3f89673ebd 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -68,7 +68,6 @@ template * = nullptr, require_not_st_arithmetic* = nullptr> inline void zero_adjoints(F& f, Pargs&... args) { apply([](auto... s) { zero_adjoints(s...); }, f.captures_); - ; zero_adjoints(args...); } From 4f1f6eb580b831892657c596880a6f3b202f418d Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Wed, 28 Jul 2021 10:52:28 +0300 Subject: [PATCH 19/25] fix includes --- stan/math/prim/fun/value_of.hpp | 1 + stan/math/prim/functor/closure_adapter.hpp | 12 +----------- stan/math/rev/core/accumulate_adjoints.hpp | 1 + stan/math/rev/core/count_vars.hpp | 1 + stan/math/rev/core/deep_copy_vars.hpp | 1 + stan/math/rev/core/save_varis.hpp | 1 + stan/math/rev/core/zero_adjoints.hpp | 1 + 7 files changed, 7 insertions(+), 11 deletions(-) diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index c114c17a4ce..6f9ef1b6326 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 5b8859a9b7c..f2a8c20ed12 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -2,7 +2,6 @@ #define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP #include -#include #include #include #include @@ -132,20 +131,11 @@ struct ode_closure_adapter { }; struct integrate_ode_closure_adapter { - template > + template auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, Args... args) const { return to_vector(f(msgs, t, to_array_1d(y), args...)); } - - template > - auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, - const std::vector& theta, const std::vector& x, - const std::vector& x_int) const { - return to_vector(f(t, to_array_1d(y), theta, x, x_int, msgs)); - } }; /** diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index f5f9f5c5ba4..ddcd9ae4d00 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP #include +#include #include #include diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index 9a3fe214da9..466d8609214 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_COUNT_VARS_HPP #include +#include #include #include diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 32bc6afece7..5fa3260b459 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index 2e017a2552c..37d246d81dc 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_REV_CORE_SAVE_VARIS_HPP #define STAN_MATH_REV_CORE_SAVE_VARIS_HPP +#include #include #include #include diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index c3f89673ebd..82e49cf3d62 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -3,6 +3,7 @@ #include #include +#include #include namespace stan { From 75f6d308204361a00175ccef8c7b5dd88ff6e1c8 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Wed, 28 Jul 2021 12:25:42 +0300 Subject: [PATCH 20/25] missing include --- stan/math/prim/functor/closure_adapter.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index f2a8c20ed12..33b20d456e9 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include From 0a609dba1de82b7715bee98d1b0ae676147ca658 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 10 Aug 2021 15:52:22 -0400 Subject: [PATCH 21/25] fixup names and references logic for closure classes --- stan/math/prim/fun/eval.hpp | 18 +--- stan/math/prim/fun/value_of.hpp | 2 +- stan/math/prim/functor/closure_adapter.hpp | 112 ++++++++++---------- stan/math/prim/meta/is_stan_closure.hpp | 36 ++++--- stan/math/prim/meta/promote_scalar_type.hpp | 2 +- stan/math/rev/core/deep_copy_vars.hpp | 2 +- 6 files changed, 84 insertions(+), 88 deletions(-) diff --git a/stan/math/prim/fun/eval.hpp b/stan/math/prim/fun/eval.hpp index 836381ce8b0..cc7f6718daf 100644 --- a/stan/math/prim/fun/eval.hpp +++ b/stan/math/prim/fun/eval.hpp @@ -7,18 +7,6 @@ namespace stan { namespace math { -/** - * Inputs which have a closure type are forwarded unmodified - * - * @tparam T Input type - * @param[in] arg Input argument - * @return Forwarded input argument - **/ -template * = nullptr> -inline T eval(T&& arg) { - return std::forward(arg); -} - /** * Inputs which have a plain_type equal to the own time are forwarded * unmodified (for Eigen expressions these types are different) @@ -28,8 +16,7 @@ inline T eval(T&& arg) { * @return Forwarded input argument **/ template , plain_type_t>* = nullptr, - require_not_stan_closure_t* = nullptr> + require_same_t, plain_type_t>* = nullptr> inline T eval(T&& arg) { return std::forward(arg); } @@ -43,8 +30,7 @@ inline T eval(T&& arg) { * @return Eval'd argument **/ template , plain_type_t>* = nullptr, - require_not_stan_closure_t* = nullptr> + require_not_same_t, plain_type_t>* = nullptr> inline decltype(auto) eval(const T& arg) { return arg.eval(); } diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 6f9ef1b6326..0cc7c5245bd 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -83,7 +83,7 @@ template * = nullptr, inline auto value_of(const F& f) { return apply( [&f](const auto&... s) { - return typename F::ValueOf__(f.f_, eval(value_of(s))...); + return typename F::partials_closure_t_(f.f_, eval(value_of(s))...); }, f.captures_); } diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 33b20d456e9..3a886ae17e0 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -16,15 +16,16 @@ namespace internal { */ template struct base_closure { - using captured_scalar_t__ = return_type_t; - using ValueOf__ + using return_scalar_t_ = return_type_t; + /*The base closure with `Ts` as the non-expression partials of `Ts`*/ + using partials_closure_t_ = base_closure())))...>; - using CopyOf__ = base_closure; - F f_; - std::tuple...> captures_; - - explicit base_closure(const F& f, const Ts&... args) - : f_(f), captures_(args...) {} + using Base_ = base_closure; + std::decay_t f_; + std::tuple...> captures_; + template * = nullptr, typename... Args> + explicit base_closure(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} template auto operator()(std::ostream* msgs, const Args&... args) const { @@ -39,19 +40,20 @@ struct base_closure { */ template struct closure_rng { - using captured_scalar_t__ = double; - using ValueOf__ = closure_rng; - using CopyOf__ = closure_rng; - F f_; - std::tuple...> captures_; + using return_scalar_t_ = double; + using partials_closure_t_ = closure_rng; + using Base_ = closure_rng; + std::decay_t f_; + std::tuple...> captures_; - explicit closure_rng(const F& f, const Ts&... args) - : f_(f), captures_(args...) {} + template * = nullptr, typename... Args> + explicit closure_rng(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} template auto operator()(Rng& rng, std::ostream* msgs, const Args&... args) const { return apply([this, &rng, msgs, &args...]( - const auto&... s) { return f_(s..., args..., rng, msgs); }, + const auto&... s) { return this->f_(s..., args..., rng, msgs); }, captures_); } }; @@ -61,21 +63,22 @@ struct closure_rng { */ template struct closure_lpdf { - using captured_scalar_t__ = return_type_t; - using ValueOf__ = closure_lpdf; - using CopyOf__ = closure_lpdf; - F f_; - std::tuple...> captures_; + using return_scalar_t_ = return_type_t; + using partials_closure_t_ = closure_lpdf; + using Base_ = closure_lpdf; + std::decay_t f_; + std::tuple...> captures_; - explicit closure_lpdf(const F& f, const Ts&... args) - : f_(f), captures_(args...) {} + template * = nullptr, typename... Args> + explicit closure_lpdf(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} template auto with_propto() { return apply( [this](const auto&... args) { return closure_lpdf < Propto && propto, true, F, - Ts... > (f_, args...); + Ts... > (this->f_, args...); }, captures_); } @@ -84,7 +87,7 @@ struct closure_lpdf { auto operator()(std::ostream* msgs, const Args&... args) const { return apply( [this, msgs, &args...](const auto&... s) { - return f_.template operator()(s..., args..., msgs); + return this->f_.template operator()(s..., args..., msgs); }, captures_); } @@ -95,14 +98,15 @@ struct closure_lpdf { */ template struct closure_lp { - using captured_scalar_t__ = return_type_t; - using ValueOf__ = closure_lp; - using CopyOf__ = closure_lp; - F f_; - std::tuple...> captures_; + using return_scalar_t_ = return_type_t; + using partials_closure_t_ = closure_lp; + using Base_ = closure_lp; + std::decay_t f_; + std::tuple...> captures_; - explicit closure_lp(const F& f, const Ts&... args) - : f_(f), captures_(args...) {} + template * = nullptr, typename... Args> + explicit closure_lp(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} template @@ -125,50 +129,50 @@ struct closure_lp { */ struct ode_closure_adapter { template - auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, - Args... args) const { - return f(msgs, t, y, args...); + auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, + Args&&... args) const { + return std::forward(f)(msgs, t, y, std::forward(args)...); } }; struct integrate_ode_closure_adapter { template - auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, - Args... args) const { - return to_vector(f(msgs, t, to_array_1d(y), args...)); + auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, + Args&&... args) const { + return to_vector(std::forward(f)(msgs, t, to_array_1d(y), std::forward(args)...)); } }; /** * Create a closure from a C++ lambda and captures. */ -template -auto from_lambda(const F& f, const Ts&... a) { - return internal::base_closure(f, a...); +template +auto from_lambda(F&& f, Args&&... args) { + return internal::base_closure(std::forward(f), std::forward(args)...); } /** * Create a closure from an rng functor. */ -template -auto rng_from_lambda(const F& f, const Ts&... a) { - return internal::closure_rng(f, a...); +template +auto rng_from_lambda(F&& f, Args&&... args) { + return internal::closure_rng(std::forward(f), std::forward(args)...); } /** * Create a closure from an lpdf functor. */ -template -auto lpdf_from_lambda(const F& f, const Ts&... a) { - return internal::closure_lpdf(f, a...); +template +auto lpdf_from_lambda(F&& f, Args&&... args) { + return internal::closure_lpdf(std::forward(f), std::forward(args)...); } /** * Create a closure from a functor that needs access to logprob accumulator. */ -template -auto lp_from_lambda(const F& f, const Ts&... args) { - return internal::closure_lp(f, args...); +template +auto lp_from_lambda(F&& f, Args&&... args) { + return internal::closure_lp(std::forward(f), std::forward(args)...); } /** @@ -177,10 +181,10 @@ auto lp_from_lambda(const F& f, const Ts&... args) { struct reduce_sum_closure_adapter { template auto operator()(const std::vector& sub_slice, std::size_t start, - std::size_t end, std::ostream* msgs, const F& f, - Args... args) const { - return f(msgs, sub_slice, start + error_index::value, - end + error_index::value, args...); + std::size_t end, std::ostream* msgs, F&& f, + Args&&... args) const { + return std::forward(f)(msgs, sub_slice, start + error_index::value, + end + error_index::value, std::forward(args)...); } }; diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp index 30f4db0861a..df41410e808 100644 --- a/stan/math/prim/meta/is_stan_closure.hpp +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -18,15 +18,16 @@ template struct is_stan_closure : std::false_type {}; template -struct is_stan_closure> +struct is_stan_closure::return_scalar_t_>> : std::true_type {}; +STAN_ADD_REQUIRE_UNARY(stan_closure, is_stan_closure, general_types); + template -struct scalar_type> { - using type = typename T::captured_scalar_t__; +struct scalar_type> { + using type = typename std::decay_t::return_scalar_t_; }; -STAN_ADD_REQUIRE_UNARY(stan_closure, is_stan_closure, general_types); template struct fn_return_type { @@ -34,8 +35,8 @@ struct fn_return_type { }; template -struct fn_return_type> { - using type = typename T::captured_scalar_t__; +struct fn_return_type::return_scalar_t_>> { + using type = typename std::decay_t::return_scalar_t_; }; /** @@ -52,23 +53,28 @@ using fn_return_type_t = return_type_t::type, Args...>; template -struct capture_type; +struct closure_return_type; template -struct capture_type { - using type = const T&; +struct closure_return_type { + using type = const std::decay_t&; }; template -struct capture_type>> { +struct closure_return_type> { using type = std::remove_reference_t; }; template -struct capture_type>> { - using type = typename std::remove_reference_t::CopyOf__; +struct closure_return_type> { + using type = typename std::remove_reference_t::Base_; +}; + +template +struct scalar_type> { + using type = typename std::decay_t::return_scalar_t_; }; /** @@ -78,7 +84,7 @@ struct capture_type -using capture_type_t = typename capture_type::type; +using closure_return_type_t = typename closure_return_type::type; } // namespace stan diff --git a/stan/math/prim/meta/promote_scalar_type.hpp b/stan/math/prim/meta/promote_scalar_type.hpp index 2f2a11382d5..a1574d81149 100644 --- a/stan/math/prim/meta/promote_scalar_type.hpp +++ b/stan/math/prim/meta/promote_scalar_type.hpp @@ -108,7 +108,7 @@ struct promote_scalar_type> { * The promoted type. */ using type = typename std::conditional::value, F, - typename F::ValueOf__>::type; + typename F::partials_closure_t_>::type; }; template diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 5fa3260b459..b8c3dd077e8 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -95,7 +95,7 @@ template * = nullptr> inline auto deep_copy_vars(const F& f) { return apply( [&f](const auto&... s) { - return typename F::CopyOf__(f.f_, eval(deep_copy_vars(s))...); + return typename F::Base_(f.f_, eval(deep_copy_vars(s))...); }, f.captures_); } From e0f6145a1066d8575ff5cb3bc854acb786e848b0 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 10 Aug 2021 16:21:01 -0400 Subject: [PATCH 22/25] remove fn_return_type --- stan/math/prim/functor/integrate_ode_rk45.hpp | 2 +- stan/math/prim/functor/ode_ckrk.hpp | 4 +-- stan/math/prim/functor/ode_rk45.hpp | 4 +-- stan/math/prim/meta/is_stan_closure.hpp | 29 ------------------- stan/math/prim/meta/return_type.hpp | 6 +++- stan/math/rev/functor/integrate_ode_adams.hpp | 4 +-- stan/math/rev/functor/integrate_ode_bdf.hpp | 4 +-- 7 files changed, 14 insertions(+), 39 deletions(-) diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index c3ad8aeadc9..47fac5b4377 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -58,7 +58,7 @@ inline auto integrate_ode_rk45( msgs, relative_tolerance, absolute_tolerance, max_num_steps); - std::vector>> + std::vector>> y_converted; y_converted.reserve(y.size()); for (size_t i = 0; i < y.size(); ++i) diff --git a/stan/math/prim/functor/ode_ckrk.hpp b/stan/math/prim/functor/ode_ckrk.hpp index 44d455ea520..3693aecefa2 100644 --- a/stan/math/prim/functor/ode_ckrk.hpp +++ b/stan/math/prim/functor/ode_ckrk.hpp @@ -195,7 +195,7 @@ ode_ckrk_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_ckrk_tol(const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -241,7 +241,7 @@ ode_ckrk_tol(const F& f, const T_y0& y0_arg, T_t0 t0, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_ckrk(const F& f, const T_y0& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index 068a267e22d..521e66cface 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -198,7 +198,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -244,7 +244,7 @@ ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45(const F& f, const T_y0& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp index df41410e808..46f985a168f 100644 --- a/stan/math/prim/meta/is_stan_closure.hpp +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -28,30 +28,6 @@ struct scalar_type> { using type = typename std::decay_t::return_scalar_t_; }; - -template -struct fn_return_type { - using type = double; -}; - -template -struct fn_return_type::return_scalar_t_>> { - using type = typename std::decay_t::return_scalar_t_; -}; - -/** - * Convenience type for the return type of the specified template - * parameters. - * - * @tparam F callable type - * @tparam Ts sequence of types - * @see return_type - * @ingroup type_trait - */ -template -using fn_return_type_t - = return_type_t::type, Args...>; - template struct closure_return_type; @@ -72,11 +48,6 @@ struct closure_return_type::Base_; }; -template -struct scalar_type> { - using type = typename std::decay_t::return_scalar_t_; -}; - /** * Type for things captured either by const reference or by copy. * diff --git a/stan/math/prim/meta/return_type.hpp b/stan/math/prim/meta/return_type.hpp index 3f9afa4f504..1ed45bd8f40 100644 --- a/stan/math/prim/meta/return_type.hpp +++ b/stan/math/prim/meta/return_type.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -116,7 +117,10 @@ using row_vector_return_t = Eigen::Matrix, 1, -1>; */ template struct scalar_lub { - using type = promote_args_t; + using type = + std::conditional_t::value && is_stan_scalar::value, + promote_args_t, std::conditional_t::value, T1, + std::conditional_t::value, T2, double>>>; }; template diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 44d7b896af3..1f2fee91cb0 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -52,7 +52,7 @@ auto integrate_ode_adams_impl(const F& f, const std::vector& y0, */ template -std::vector>> +std::vector>> integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -65,7 +65,7 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, f, y0, t0, ts, theta, x, x_int, msgs, relative_tolerance, absolute_tolerance, max_num_steps); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index 1f7269a1897..4e25abd1b1b 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -52,7 +52,7 @@ auto integrate_ode_bdf_impl(const F& f, const std::vector& y0, */ template > -std::vector>> +std::vector>> integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -65,7 +65,7 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, msgs, relative_tolerance, absolute_tolerance, max_num_steps); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); From 9436a182713344526caeba0d7735c1a47ce9953e Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 12 Aug 2021 09:59:24 +0000 Subject: [PATCH 23/25] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/prim/functor/closure_adapter.hpp | 26 ++++++++++++++-------- stan/math/prim/meta/is_stan_closure.hpp | 6 ++--- stan/math/prim/meta/return_type.hpp | 10 +++++---- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 3a886ae17e0..5d039db9107 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -52,9 +52,11 @@ struct closure_rng { template auto operator()(Rng& rng, std::ostream* msgs, const Args&... args) const { - return apply([this, &rng, msgs, &args...]( - const auto&... s) { return this->f_(s..., args..., rng, msgs); }, - captures_); + return apply( + [this, &rng, msgs, &args...](const auto&... s) { + return this->f_(s..., args..., rng, msgs); + }, + captures_); } }; @@ -139,7 +141,8 @@ struct integrate_ode_closure_adapter { template auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, Args&&... args) const { - return to_vector(std::forward(f)(msgs, t, to_array_1d(y), std::forward(args)...)); + return to_vector(std::forward(f)(msgs, t, to_array_1d(y), + std::forward(args)...)); } }; @@ -148,7 +151,8 @@ struct integrate_ode_closure_adapter { */ template auto from_lambda(F&& f, Args&&... args) { - return internal::base_closure(std::forward(f), std::forward(args)...); + return internal::base_closure(std::forward(f), + std::forward(args)...); } /** @@ -156,7 +160,8 @@ auto from_lambda(F&& f, Args&&... args) { */ template auto rng_from_lambda(F&& f, Args&&... args) { - return internal::closure_rng(std::forward(f), std::forward(args)...); + return internal::closure_rng(std::forward(f), + std::forward(args)...); } /** @@ -164,7 +169,8 @@ auto rng_from_lambda(F&& f, Args&&... args) { */ template auto lpdf_from_lambda(F&& f, Args&&... args) { - return internal::closure_lpdf(std::forward(f), std::forward(args)...); + return internal::closure_lpdf( + std::forward(f), std::forward(args)...); } /** @@ -172,7 +178,8 @@ auto lpdf_from_lambda(F&& f, Args&&... args) { */ template auto lp_from_lambda(F&& f, Args&&... args) { - return internal::closure_lp(std::forward(f), std::forward(args)...); + return internal::closure_lp( + std::forward(f), std::forward(args)...); } /** @@ -184,7 +191,8 @@ struct reduce_sum_closure_adapter { std::size_t end, std::ostream* msgs, F&& f, Args&&... args) const { return std::forward(f)(msgs, sub_slice, start + error_index::value, - end + error_index::value, std::forward(args)...); + end + error_index::value, + std::forward(args)...); } }; diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp index 46f985a168f..3ff256c3de0 100644 --- a/stan/math/prim/meta/is_stan_closure.hpp +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -37,14 +37,12 @@ struct closure_return_type { }; template -struct closure_return_type> { +struct closure_return_type> { using type = std::remove_reference_t; }; template -struct closure_return_type> { +struct closure_return_type> { using type = typename std::remove_reference_t::Base_; }; diff --git a/stan/math/prim/meta/return_type.hpp b/stan/math/prim/meta/return_type.hpp index 1ed45bd8f40..0f41544ef26 100644 --- a/stan/math/prim/meta/return_type.hpp +++ b/stan/math/prim/meta/return_type.hpp @@ -117,10 +117,12 @@ using row_vector_return_t = Eigen::Matrix, 1, -1>; */ template struct scalar_lub { - using type = - std::conditional_t::value && is_stan_scalar::value, - promote_args_t, std::conditional_t::value, T1, - std::conditional_t::value, T2, double>>>; + using type = std::conditional_t< + is_stan_scalar::value && is_stan_scalar::value, + promote_args_t, + std::conditional_t< + is_stan_scalar::value, T1, + std::conditional_t::value, T2, double>>>; }; template From 2b2bee208153d89f11b89b15bd206e667cd769c5 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Fri, 13 Aug 2021 16:37:20 -0400 Subject: [PATCH 24/25] fix headers --- stan/math/prim/functor/closure_adapter.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 5d039db9107..5e6b7c73c1f 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -1,9 +1,7 @@ #ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP #define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP -#include -#include -#include +#include #include #include From 03ab504917b3af030f7226e2abb80672f0a728de Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 7 Sep 2021 13:20:56 +0300 Subject: [PATCH 25/25] some docs, a test --- stan/math/prim/functor/closure_adapter.hpp | 32 +++++++++++++++---- .../math/prim/meta/is_stan_closure_test.cpp | 18 +++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 test/unit/math/prim/meta/is_stan_closure_test.cpp diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 5e6b7c73c1f..a708d27c16c 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -11,6 +11,10 @@ namespace internal { /** * A closure that wraps a C++ lambda and captures values. + * + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values */ template struct base_closure { @@ -27,14 +31,20 @@ struct base_closure { template auto operator()(std::ostream* msgs, const Args&... args) const { - return apply([this, msgs, &args...]( - const auto&... s) { return f_(s..., args..., msgs); }, - captures_); + return apply( + [this, msgs, &args...](const auto&... s) { + return this->f_(s..., args..., msgs); + }, + captures_); } }; /** * A closure that takes rng argument. + * + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values */ template struct closure_rng { @@ -59,7 +69,12 @@ struct closure_rng { }; /** - * A closure that can be called with `propto` template argument. + * A closure that may compute an unnormalized propability density. + * + * @tparam Propto if true the function is unnormalized + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values */ template struct closure_lpdf { @@ -95,6 +110,11 @@ struct closure_lpdf { /** * A closure that accesses logprob accumulator. + * + * @tparam Propto if true the logprob is unnormalized + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values */ template struct closure_lp { @@ -114,8 +134,8 @@ struct closure_lp { const Args&... args) const { return apply( [this, &lp, &lp_accum, msgs, &args...](const auto&... s) { - return f_.template operator()(s..., args..., lp, lp_accum, - msgs); + return this->f_.template operator()(s..., args..., lp, + lp_accum, msgs); }, captures_); } diff --git a/test/unit/math/prim/meta/is_stan_closure_test.cpp b/test/unit/math/prim/meta/is_stan_closure_test.cpp new file mode 100644 index 00000000000..7138e31f448 --- /dev/null +++ b/test/unit/math/prim/meta/is_stan_closure_test.cpp @@ -0,0 +1,18 @@ +#include +#include +#include +#include + +TEST(MathMetaPrim, IsStanClosure) { + auto lambda = [](auto msg) { return 0.0; }; + auto cl = stan::math::from_lambda(lambda); + EXPECT_FALSE((stan::is_stan_closure::value)); + EXPECT_TRUE((stan::is_stan_closure::value)); +} + +TEST(MathMetaPrim, ClosureReturnType) { + EXPECT_SAME_TYPE(const std::vector&, + stan::closure_return_type, true>::type); + EXPECT_SAME_TYPE(std::vector, + stan::closure_return_type, false>::type); +}