From 107b44c4dd6af578649af3ae1bd741c71f44e526 Mon Sep 17 00:00:00 2001 From: Ben Margolis Date: Mon, 19 Jun 2023 10:51:04 -0700 Subject: [PATCH 1/3] add rootinfo and fix bug to pass userdata to JacRhs --- scikits/odes/sundials/cvode.pxd | 9 ++++---- scikits/odes/sundials/cvode.pyx | 39 +++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/scikits/odes/sundials/cvode.pxd b/scikits/odes/sundials/cvode.pxd index b3773bda..54f0fdbd 100644 --- a/scikits/odes/sundials/cvode.pxd +++ b/scikits/odes/sundials/cvode.pxd @@ -23,18 +23,19 @@ cdef class CV_RootFunction: cdef class CV_WrapRootFunction(CV_RootFunction): cdef object _rootfn - cdef int with_userdata + cdef public int with_userdata cpdef set_rootfn(self, object rootfn) cdef class CV_JacRhsFunction: cpdef int evaluate(self, DTYPE_t t, np.ndarray[DTYPE_t, ndim=1] y, np.ndarray[DTYPE_t, ndim=1] fy, - np.ndarray[DTYPE_t, ndim=2] J) except? -1 + np.ndarray[DTYPE_t, ndim=2] J, + object userdata = *) except? -1 cdef class CV_WrapJacRhsFunction(CV_JacRhsFunction): cdef public object _jacfn - cdef int with_userdata + cdef public int with_userdata cpdef set_jacfn(self, object jacfn) cdef class CV_PrecSetupFunction: @@ -128,7 +129,7 @@ cdef class CVODE: cdef N_Vector atol cdef void* _cv_mem cdef SUNContext sunctx - cdef dict options + cdef public dict options cdef bint parallel_implementation, initialized, _old_api, _step_compute, _validate_flags cdef CV_data aux_data diff --git a/scikits/odes/sundials/cvode.pyx b/scikits/odes/sundials/cvode.pyx index 91ce51ca..a8247ab4 100644 --- a/scikits/odes/sundials/cvode.pyx +++ b/scikits/odes/sundials/cvode.pyx @@ -257,7 +257,8 @@ cdef class CV_JacRhsFunction: cpdef int evaluate(self, DTYPE_t t, np.ndarray[DTYPE_t, ndim=1] y, np.ndarray[DTYPE_t, ndim=1] fy, - np.ndarray[DTYPE_t, ndim=2] J) except? -1: + np.ndarray[DTYPE_t, ndim=2] J, + object userdata = None) except? -1: """ Returns the Jacobi matrix of the right hand side function, as d(rhs)/d y @@ -275,22 +276,30 @@ cdef class CV_WrapJacRhsFunction(CV_JacRhsFunction): """ Set some jacobian equations as a JacRhsFunction executable class. """ + self.with_userdata = 0 + self._jacfn = jacfn + nrarg = _get_num_args(jacfn) + if nrarg > 5: + #hopefully a class method, self gives 6 arg! + self.with_userdata = 1 + elif nrarg == 5 and inspect.isfunction(jacfn): + self.with_userdata = 1 self._jacfn = jacfn cpdef int evaluate(self, DTYPE_t t, np.ndarray[DTYPE_t, ndim=1] y, np.ndarray[DTYPE_t, ndim=1] fy, - np.ndarray J) except? -1: + np.ndarray J, + object userdata = None) except? -1: """ Returns the Jacobi matrix (for dense the full matrix, for band only bands. Result has to be stored in the variable J, which is preallocated to the corresponding size. """ -## if self.with_userdata == 1: -## self._jacfn(t, y, ydot, cj, J, userdata) -## else: -## self._jacfn(t, y, ydot, cj, J) - user_flag = self._jacfn(t, y, fy, J) + if self.with_userdata == 1: + user_flag = self._jacfn(t, y, fy, J, userdata) + else: + user_flag = self._jacfn(t, y, fy, J) if user_flag is None: user_flag = 0 @@ -318,7 +327,7 @@ cdef int _jacdense(sunrealtype tt, ff_tmp = aux_data.z_tmp nv_s2ndarray(ff, ff_tmp) - user_flag = aux_data.jac.evaluate(tt, yy_tmp, ff_tmp, jac_tmp) + user_flag = aux_data.jac.evaluate(tt, yy_tmp, ff_tmp, jac_tmp, aux_data.user_data,) if parallel_implementation: raise NotImplemented @@ -1644,8 +1653,12 @@ cdef class CVODE: else: _test = np.empty((len(y0), len(y0)), DTYPE) _fy_test = np.zeros(len(y0), DTYPE) - jac._jacfn(t0, y0, _fy_test, _test) + if jac.with_userdata: + jac._jacfn(t0, y0, _fy_test, _test, opts['user_data']) + else: + jac._jacfn(t0, y0, _fy_test, _test) _test = None + _fy_test = None #now we initialize storage which is persistent over steps self.t_roots = [] @@ -1896,6 +1909,14 @@ cdef class CVODE: self.t_tstop, self.y_tstop, ) + def rootinfo(self): + #cdef int[self.options['nr_rootfns']] rootsfound + N = self.options['nr_rootfns'] + cdef np.ndarray[int, ndim=1, mode='c'] rootsfound = np.empty(N, dtype=np.int32) + #cdef int rootsfound[N] + CVodeGetRootInfo(self._cv_mem, &rootsfound[0]) + return rootsfound + def step(self, DTYPE_t t, np.ndarray[DTYPE_t, ndim=1] y_retn = None): """ From cdfb3bfb12027907fc8ffa81fab48ff3424e33f9 Mon Sep 17 00:00:00 2001 From: Ben Margolis Date: Fri, 21 Jul 2023 19:29:45 -0700 Subject: [PATCH 2/3] allow tstop of any value and setting of max_step_size after solver instantiation --- scikits/odes/sundials/cvode.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scikits/odes/sundials/cvode.pyx b/scikits/odes/sundials/cvode.pyx index a8247ab4..a3c1fa63 100644 --- a/scikits/odes/sundials/cvode.pyx +++ b/scikits/odes/sundials/cvode.pyx @@ -1077,7 +1077,7 @@ cdef class CVODE: if not supress_supported_check: for opt in options.keys(): if not opt in ['atol', 'rtol', 'tstop', 'rootfn', 'nr_rootfns', - 'verbosity', 'one_step_compute']: + 'verbosity', 'one_step_compute', 'max_step_size']: raise ValueError("Option '%s' can''t be set runtime." % opt) # Verbosity level @@ -1181,7 +1181,7 @@ cdef class CVODE: if ('tstop' in options) and (options['tstop'] is not None): opts_tstop = options['tstop'] self.options['tstop'] = opts_tstop - if (not opts_tstop is None) and (opts_tstop > 0.): + if (not opts_tstop is None): flag = CVodeSetStopTime(cv_mem, opts_tstop) if flag == CV_ILL_INPUT: raise ValueError('CVodeSetStopTime::Stop value is beyond ' From d2289c4d3efd03a83285549eb3ff4599e1a9681a Mon Sep 17 00:00:00 2001 From: Benjamin Margolis Date: Fri, 31 Jan 2025 15:38:47 -0800 Subject: [PATCH 3/3] fixup? --- docs/examples/ode/class_based.py | 113 ++++++++++++++++++++++ docs/examples/ode/cvodes_test.py | 81 ++++++++++++++++ docs/examples/ode/feature_test.py | 106 ++++++++++++++++++++ docs/examples/ode/simpleoscillator.py | 4 +- docs/examples/ode/simpleoscillator_jac.py | 54 ++++++++++- docs/examples/planarpendulum.py | 2 + scikits/odes/sundials/cvode.pyx | 20 ++-- scikits/odes/sundials/cvodes.pxd | 5 + scikits/odes/sundials/cvodes.pyx | 16 ++- setup.py | 1 + 10 files changed, 382 insertions(+), 20 deletions(-) create mode 100644 docs/examples/ode/class_based.py create mode 100644 docs/examples/ode/cvodes_test.py create mode 100644 docs/examples/ode/feature_test.py diff --git a/docs/examples/ode/class_based.py b/docs/examples/ode/class_based.py new file mode 100644 index 00000000..274c2671 --- /dev/null +++ b/docs/examples/ode/class_based.py @@ -0,0 +1,113 @@ +# Authors: B. Malengier +""" +This example shows the most simple way of using a solver. +We solve free vibration of a simple oscillator:: + m \ddot{u} + k u = 0, u(0) = u_0, \dot{u}(0) = \dot{u}_0 +using the CVODE solver, which means we use a rhs function of \dot{u}. +Solution:: + u(t) = u_0*cos(sqrt(k/m)*t)+\dot{u}_0*sin(sqrt(k/m)*t)/sqrt(k/m) + +""" +from __future__ import print_function +from numpy import asarray, cos, sin, sqrt +import numpy as np +from scikits.odes.sundials.cvode import CVODE, StatusEnum, CV_WrapJacRhsFunction +from collections import namedtuple + +#data +k = 4.0 +m = 1.0 +t1 = 10. +#initial data on t=0, x[0] = u, x[1] = \dot{u}, xp = \dot{x} +initx = [1, 0.1] + +def rhseqn(t, x): + """ we create rhs equations for the problem""" + return [ + x[1], + - k/m * x[0] + ] + +def jaceqn(t, x,): + jac = np.zeros((2,2)) + jac[0,1] = 1 + jac[1,0] = -k/m + return jac + + +def rootfn(t, x): + return ( + x[0], + x[1], + t - t1, + np.sin(t), + ) + +Root = namedtuple("Root", ["index", "rootsfound"]) +Results = namedtuple("Results", ["t", "x", "e"],) + +class System: + def __init__(self, dots, jac, events, num_events): + self._dots = dots + self._jac = jac + self._events = events + self.num_events = num_events + self.solver = CVODE( + self.dots, + jacfn=self.jac, + old_api=False, + one_step_compute=True, + rootfn=self.events, + nr_rootfns=self.num_events, + ) + + def dots(self, t, x, xdot,):# userdata=None,): + xdot[:] = self._dots(t, x) + + def jac(self, t, x, xdot, jac, userdata=None,): + jac[...] = self._jac(t, x) + + def events(self, t, x, g): + g[:] = self._events(t, x) + + def simulate(self, t0, x0, tf, results=None): + + if results is None: + results = Results([],[],[]) + + dense_t = results.t + dense_y = results.x + roots = results.e + + solver = self.solver + + solver.init_step(t0, x0) + solver.set_options(tstop=tf) + + dense_t.append(np.copy(t0)) + dense_y.append(np.copy(x0)) + + for cnt in range(1000): + res = solver.step(t1) + print(cnt, res.flag, res.values.t) + dense_t.append(np.copy(res.values.t)) + dense_y.append(np.copy(res.values.y)) + match res.flag: + case StatusEnum.ROOT_RETURN: + rootsfound = solver.rootinfo() + roots.append(Root(cnt, rootsfound)) + + if res.values.t == tf: + break + + case StatusEnum.TSTOP_RETURN: + #continue + break + + return results + + + +sys = System(rhseqn, jaceqn, rootfn, 4) +res1 = sys.simulate(0., initx, 11.) +res2 = sys.simulate(0., initx, 10.) diff --git a/docs/examples/ode/cvodes_test.py b/docs/examples/ode/cvodes_test.py new file mode 100644 index 00000000..adf6175d --- /dev/null +++ b/docs/examples/ode/cvodes_test.py @@ -0,0 +1,81 @@ +# Authors: B. Malengier +""" +This example shows the most simple way of using a solver. +We solve free vibration of a simple oscillator:: + m \ddot{u} + k u = 0, u(0) = u_0, \dot{u}(0) = \dot{u}_0 +using the CVODE solver, which means we use a rhs function of \dot{u}. +Solution:: + u(t) = u_0*cos(sqrt(k/m)*t)+\dot{u}_0*sin(sqrt(k/m)*t)/sqrt(k/m) + +""" +from __future__ import print_function +from numpy import asarray, cos, sin, sqrt +import numpy as np + +#data +userdata = dict( +k = 4.0, +m = 1.0, +t1 = 10., +rhs_calls = 0, +jac_calls = 0, +) +#initial data on t=0, x[0] = u, x[1] = \dot{u}, xp = \dot{x} +initx = [1, 0.1] + +#define function for the right-hand-side equations which has specific signature +def rhseqn(t, x, xdot, my_user_data): + """ we create rhs equations for the problem""" + + k = my_user_data['k'] + m = my_user_data['m'] + my_user_data['rhs_calls'] += 1 + xdot[0] = x[1] + xdot[1] = - k/m * x[0] + +def jaceqn(t, x, fx, jac, my_user_data):#=None): + my_user_data['jac_calls'] += 1 + if my_user_data is None: + print("ERROR") + return 0 + + k = my_user_data['k'] + m = my_user_data['m'] + + jac[0,1] = 1 + jac[1,0] = -k/m + +#instantiate the solver +if False: + from scikits.odes.sundials.cvode import CVODE, StatusEnum, CV_WrapJacRhsFunction + SolverClass = CVODE +else: + from scikits.odes.sundials.cvode import CV_WrapJacRhsFunction + from scikits.odes.sundials.cvodes import CVODES, StatusEnum + SolverClass = CVODES + +from collections import namedtuple + +def rootfn(t, x, g, my_user_data): + t1 = my_user_data['t1'] + g[0] = x[0] + g[1] = x[1] + g[2] = t - t1 + +solver = SolverClass( + rhseqn, user_data=userdata,# jacfn=jaceqn, + old_api=False, one_step_compute=True, + rootfn=rootfn, nr_rootfns=3, ) + +next_tstop = 10. +#solver.init_step(0., initx) +solver.set_options(tstop = next_tstop) +dense_t = [] +dense_y = [] +roots = [] +Root = namedtuple("Root", ["index", "rootsfound"]) +print("starting loop") +#print(solver.get_info(),) +res = solver.solve([0, 10.0], initx) +print("completed:",res) +#print(solver.get_info(), solver.num_chk_pts, solver.options["rfn"]) diff --git a/docs/examples/ode/feature_test.py b/docs/examples/ode/feature_test.py new file mode 100644 index 00000000..1f0a00b9 --- /dev/null +++ b/docs/examples/ode/feature_test.py @@ -0,0 +1,106 @@ +# Authors: B. Malengier +""" +This example shows the most simple way of using a solver. +We solve free vibration of a simple oscillator:: + m \ddot{u} + k u = 0, u(0) = u_0, \dot{u}(0) = \dot{u}_0 +using the CVODE solver, which means we use a rhs function of \dot{u}. +Solution:: + u(t) = u_0*cos(sqrt(k/m)*t)+\dot{u}_0*sin(sqrt(k/m)*t)/sqrt(k/m) + +""" +from __future__ import print_function +from numpy import asarray, cos, sin, sqrt +import numpy as np + +#data +userdata = dict( +k = 4.0, +m = 1.0, +t1 = 10., +rhs_calls = 0, +jac_calls = 0, +) +#initial data on t=0, x[0] = u, x[1] = \dot{u}, xp = \dot{x} +initx = [1, 0.1] + +#define function for the right-hand-side equations which has specific signature +def rhseqn(t, x, xdot, my_user_data): + """ we create rhs equations for the problem""" + + k = my_user_data['k'] + m = my_user_data['m'] + my_user_data['rhs_calls'] += 1 + xdot[0] = x[1] + xdot[1] = - k/m * x[0] + +def jaceqn(t, x, fx, jac, my_user_data):#=None): + my_user_data['jac_calls'] += 1 + if my_user_data is None: + print("ERROR") + return 0 + + k = my_user_data['k'] + m = my_user_data['m'] + + jac[0,1] = 1 + jac[1,0] = -k/m + +#instantiate the solver +#from scikits.odes.sundials.cvode import CVODE, StatusEnum, CV_WrapJacRhsFunction +#SolverClass = CVODE +from scikits.odes.sundials.cvode import CV_WrapJacRhsFunction +from scikits.odes.sundials.cvodes import CVODES, StatusEnum +SolverClass = CVODES +from collections import namedtuple + +def rootfn(t, x, g, my_user_data): + t1 = my_user_data['t1'] + g[0] = x[0] + g[1] = x[1] + g[2] = t - t1 + +solver = SolverClass( + rhseqn, user_data=userdata,# jacfn=jaceqn, + old_api=False, one_step_compute=True, + rootfn=rootfn, nr_rootfns=3, ) + +next_tstop = 10. +solver.init_step(0., initx) +solver.set_options(tstop = next_tstop) +dense_t = [] +dense_y = [] +roots = [] +Root = namedtuple("Root", ["index", "rootsfound"]) +print("starting loop") +#print(solver.get_info(), solver.num_chk_pts, solver.options["rfn"]) +res = solver.solve([0, 10.0], initx) + +for cnt in range(1000): + #res = solver.step(1.) + s + print(cnt, res.flag, res.values.t) + dense_t.append(np.copy(res.values.t)) + dense_y.append(np.copy(res.values.y)) + match res.flag: + case StatusEnum.ROOT_RETURN: + rootsfound = solver.rootinfo() + roots.append(Root(cnt, rootsfound)) + + case StatusEnum.TSTOP_RETURN: + #continue + break + + + if res.values.t > 10.01: + print("broke from t") + break + #print(cnt, res) + #if res.values.y[0] <= 0.: + # break +t = np.array(dense_t) +y = np.array(dense_y) +print(solver.get_info()) +print(userdata) +print(cnt) +print(solver.num_chk_pts) + diff --git a/docs/examples/ode/simpleoscillator.py b/docs/examples/ode/simpleoscillator.py index db12bfbd..ea49a34b 100644 --- a/docs/examples/ode/simpleoscillator.py +++ b/docs/examples/ode/simpleoscillator.py @@ -25,7 +25,7 @@ def rhseqn(t, x, xdot): #instantiate the solver from scikits.odes import ode -solver = ode('cvode', rhseqn, old_api=True) +solver = ode('cvode', rhseqn, old_api=True, rtol=1E-12, atol=1E-15) #obtain solution at a required time result = solver.solve([0., 10., 20.], initx) @@ -84,4 +84,4 @@ def scrhseqn(t, x): print('%4.2f %15.6g %15.6g' % (solver.t, solver.y[0], initx[0]*cos(sqrt(k/m)*solver.t)+initx[1]*sin(sqrt(k/m)*solver.t)/sqrt(k/m))) solver.integrate(solver.t+100) -print('%4.2f %15.6g %15.6g' % (solver.t, solver.y[0], initx[0]*cos(sqrt(k/m)*solver.t)+initx[1]*sin(sqrt(k/m)*solver.t)/sqrt(k/m))) \ No newline at end of file +print('%4.2f %15.6g %15.6g' % (solver.t, solver.y[0], initx[0]*cos(sqrt(k/m)*solver.t)+initx[1]*sin(sqrt(k/m)*solver.t)/sqrt(k/m))) diff --git a/docs/examples/ode/simpleoscillator_jac.py b/docs/examples/ode/simpleoscillator_jac.py index 28ba7c37..5aac4113 100644 --- a/docs/examples/ode/simpleoscillator_jac.py +++ b/docs/examples/ode/simpleoscillator_jac.py @@ -10,6 +10,7 @@ """ from __future__ import print_function from numpy import asarray, cos, sin, sqrt +import numpy as np #data k = 4.0 @@ -29,20 +30,63 @@ def jaceqn(t, x, fx, jac): #instantiate the solver from scikits.odes import ode -solver = ode('cvode', rhseqn, jacfn=jaceqn) +solver = ode('cvode', rhseqn, jacfn=jaceqn, ) #obtain solution at a required time -result = solver.solve([0., 1., 2.], initx) +result = solver.solve([0., 10., 20.], initx) print('\n t Solution Exact') print('------------------------------------') -for t, u in zip(result[1], result[2]): +#for t, u in zip(result[1], result[2]): +for t, u in zip(result.values.t, result.values.y): print('%4.2f %15.6g %15.6g' % (t, u[0], initx[0]*cos(sqrt(k/m)*t)+initx[1]*sin(sqrt(k/m)*t)/sqrt(k/m))) #continue the solver -result = solver.solve([result[1][-1], result[1][-1]+1], result[2][-1]) +#result = solver.solve([result[1][-1], result[1][-1]+1], result[2][-1]) +result = solver.solve([result.values.t[-1], result.values.t[-1]+1, result.values.t[-1]+110], result.values.y[-1]) print('------------------------------------') print(' ...continuation of the solution') print('------------------------------------') -for t, u in zip(result[1], result[2]): +#for t, u in zip(result[1], result[2]): +for t, u in zip(result.values.t, result.values.y): print ('%4.2f %15.6g %15.6g' % (t, u[0], initx[0]*cos(sqrt(k/m)*t)+initx[1]*sin(sqrt(k/m)*t)/sqrt(k/m))) + + +from scikits.odes.sundials.cvode import CVODE, StatusEnum +from collections import namedtuple + +def rootfn(t, x, g): + g[0] = x[0] + g[1] = x[1] + g[2] = t - 10. +solver = CVODE(rhseqn, jacfn=jaceqn, old_api=False, one_step_compute=True, + rootfn=rootfn, nr_rootfns=3) + +next_tstop = 10. +solver.init_step(0., initx) +solver.set_options(tstop = next_tstop) +dense_t = [] +dense_y = [] +roots = [] +Root = namedtuple("Root", ["index", "rootsfound"]) +for cnt in range(1000): + res = solver.step(1.) + print(cnt, res.flag, res.values.t) + dense_t.append(np.copy(res.values.t)) + dense_y.append(np.copy(res.values.y)) + match res.flag: + case StatusEnum.ROOT_RETURN: + rootsfound = solver.rootinfo() + roots.append(Root(cnt, rootsfound)) + + case StatusEnum.TSTOP_RETURN: + break +t = np.array(dense_t) +y = np.array(dense_y) + + + #if res.values.t > 9.99: + # break + #print(cnt, res) + #if res.values.y[0] <= 0.: + # break diff --git a/docs/examples/planarpendulum.py b/docs/examples/planarpendulum.py index d7c7e072..37190e2b 100644 --- a/docs/examples/planarpendulum.py +++ b/docs/examples/planarpendulum.py @@ -47,6 +47,8 @@ from scikits.odes.sundials import ida import matplotlib.pyplot as plt +np.float = float + def draw_graphs(fignum, t, x, y): plt.ion() plt.figure(fignum) diff --git a/scikits/odes/sundials/cvode.pyx b/scikits/odes/sundials/cvode.pyx index a3c1fa63..9e737036 100644 --- a/scikits/odes/sundials/cvode.pyx +++ b/scikits/odes/sundials/cvode.pyx @@ -5,6 +5,8 @@ from enum import IntEnum import inspect from warnings import warn +# Use python setup.py intall to install + include "sundials_config.pxi" import numpy as np @@ -1971,15 +1973,15 @@ cdef class CVODE: t_err = None y_err = None sol_t_out = t_out - if flagCV == CV_SUCCESS or flag == CV_WARNING: - pass - elif flagCV == CV_ROOT_RETURN: - self.t_roots.append(np.copy(t_out)) - self.y_roots.append(np.copy(y_out)) - elif flagCV == CV_TSTOP_RETURN: - self.t_tstop.append(np.copy(t_out)) - self.y_tstop.append(np.copy(y_out)) - elif flagCV < 0: + #if flagCV == CV_SUCCESS or flag == CV_WARNING: + # pass + #elif flagCV == CV_ROOT_RETURN: + # self.t_roots.append(np.copy(t_out)) + # self.y_roots.append(np.copy(y_out)) + #elif flagCV == CV_TSTOP_RETURN: + # self.t_tstop.append(np.copy(t_out)) + # self.y_tstop.append(np.copy(y_out)) + if flagCV < 0: t_err = np.copy(t_out) y_err = np.copy(y_out) sol_t_out = None diff --git a/scikits/odes/sundials/cvodes.pxd b/scikits/odes/sundials/cvodes.pxd index 37ff740c..3333c654 100644 --- a/scikits/odes/sundials/cvodes.pxd +++ b/scikits/odes/sundials/cvodes.pxd @@ -20,5 +20,10 @@ cdef class CVS_data(CV_data): cdef class CVODES(CVODE): cdef N_Vector aStol cdef CVS_data aux_dataS + cdef public int num_chk_pts cdef int Ns #sensitivity parameter size + cpdef _init_adjoint_step(self) + cpdef _solve(self, np.ndarray[DTYPE_t, ndim=1] tspan, + np.ndarray[DTYPE_t, ndim=1] y0) + diff --git a/scikits/odes/sundials/cvodes.pyx b/scikits/odes/sundials/cvodes.pyx index 69905d26..06346ffa 100644 --- a/scikits/odes/sundials/cvodes.pyx +++ b/scikits/odes/sundials/cvodes.pyx @@ -29,7 +29,7 @@ from .cvode cimport (CV_RhsFunction, CV_WrapRhsFunction, CV_RootFunction, CV_JacTimesVecFunction, CV_WrapJacTimesVecFunction, CV_JacTimesSetupFunction, CV_WrapJacTimesSetupFunction, CV_ContinuationFunction, CV_ErrHandler, - CV_WrapErrHandler, CV_data, CVODE) + CV_WrapErrHandler, CV_data, CVODE,) from .c_cvodes cimport * from .common_defs cimport ( nv_s2ndarray, ndarray2nv_s, ndarray2SUNMatrix, DTYPE_t, INDEX_TYPE_t, @@ -162,7 +162,7 @@ cdef class CVS_data(CV_data): cdef class CVODES(CVODE): - def __cinit__(self, Rfn, **options): + def __cinit__(self, Rfn, num_steps_per_chk=1, interp_type=CV_POLYNOMIAL, **options): """ Initialize the CVODE Solver and it's default values @@ -172,9 +172,11 @@ cdef class CVODES(CVODE): of supported options and their values see set_options() """ - super(CVODES, self).__init__(Rfn, **options) - + self.num_chk_pts = 0 self.Ns = -1 + super(CVODES, self).__init__(Rfn, **options) + self.options["num_steps_per_chk"] = num_steps_per_chk + self.options["interp_type"] = interp_type def set_options(self, **options): """ @@ -453,8 +455,14 @@ cdef class CVODES(CVODE): """ soln = super(CVODES, self).init_step(t0, y0) + #self._init_adjoint_step() + return soln + cpdef _init_adjoint_step(self):#, long int steps, int interp): + cdef void* cv_mem = self._cv_mem + CVodeAdjInit(cv_mem, self.options["num_steps_per_chk"], self.options["interp_type"]) + def solve(self, object tspan, object y0): """ Runs the solver. diff --git a/setup.py b/setup.py index 1c963457..dfef1493 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ # If any package contains *.pxd files, include them: '': ['*.pxd'], }, + gdb_debug=True, classifiers = CLASSIFIERS, **additional_kwargs )