Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scikits/odes/sundials/cvode.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ cdef class CV_data:
cdef object user_data
cdef CV_ErrHandler err_handler
cdef object err_user_data
cdef void* cv_mem
cdef object jac_viewer

cdef class CVODE:
cdef N_Vector atol
Expand Down
56 changes: 55 additions & 1 deletion scikits/odes/sundials/cvode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,49 @@ cdef int _jacdense(long int Neq, realtype tt,

return user_flag

cdef extern int cvDlsDenseDQJac(long int Neq, realtype tt,
N_Vector yy, N_Vector ff, DlsMat Jac,
void *auxiliary_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)

cdef int _jac_viewer_dense(long int Neq, realtype tt,
N_Vector yy, N_Vector ff, DlsMat Jac,
void *auxiliary_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) except? -1:
"""
function with the signature of CVDlsDenseJacFn that calls user/builtin jac
function and calls jac_viewer on the result
"""
cdef np.ndarray[DTYPE_t, ndim=1] yy_tmp, yp_tmp
cdef np.ndarray jac_tmp
cdef int jac_flag

aux_data = <CV_data> auxiliary_data
cdef bint parallel_implementation = aux_data.parallel_implementation
if parallel_implementation:
raise NotImplemented
else:
if aux_data.jac:
jac_flag = _jacdense(
Neq, tt, yy, ff, Jac, auxiliary_data, tmp1, tmp2, tmp3
)
else:
jac_flag = cvDlsDenseDQJac(
Neq, tt, yy, ff, Jac, aux_data.cv_mem, tmp1, tmp2, tmp3
)
yy_tmp = aux_data.yy_tmp
yp_tmp = aux_data.yp_tmp
if aux_data.jac_tmp is None:
N = np.alen(yy_tmp)
aux_data.jac_tmp = np.empty((N,N), float)
jac_tmp = aux_data.jac_tmp

nv_s2ndarray(yy, yy_tmp)
nv_s2ndarray(ff, yp_tmp)
user_flag = aux_data.jac_viewer(
"dense", jac_flag, tt, yy_tmp, yp_tmp, jac_tmp, aux_data,
)

return user_flag

# Precondioner setup funtion
cdef class CV_PrecSetupFunction:
"""
Expand Down Expand Up @@ -638,6 +681,7 @@ cdef class CV_data:
self.g_tmp = None
self.r_tmp = None
self.z_tmp = None
self.jac_viewer = None

cdef class CVODE:

Expand Down Expand Up @@ -686,6 +730,7 @@ cdef class CVODE:
'onroot': None,
'ontstop': None,
'validate_flags': None,
'jac_viewer': None,
}

self.verbosity = 1
Expand Down Expand Up @@ -919,6 +964,10 @@ cdef class CVODE:
Controls whether to validate flags as a result of calling
`solve`. See the `validate_flags` function for how this
affects `solve`.
'jac_viewer':
Description:
Function which examines the jacobian produced by cvode or by
the user.
"""

# Update values of all supplied options
Expand Down Expand Up @@ -1226,6 +1275,7 @@ cdef class CVODE:

# Initialize auxiliary variables
self.aux_data = CV_data(N)
self.aux_data.cv_mem = self._cv_mem

# Set err_handler
err_handler = opts.get('err_handler', None)
Expand Down Expand Up @@ -1267,6 +1317,8 @@ cdef class CVODE:
opts['jacfn'] = tmpfun
self.aux_data.jac = jac

self.aux_data.jac_viewer = opts['jac_viewer']

#we test if rfn call doesn't give errors due to bad coding, as
#cvode will ignore errors, it only checks return value (0 or 1 for error)
if isinstance(rfn, CV_WrapRhsFunction):
Expand Down Expand Up @@ -1433,7 +1485,9 @@ cdef class CVODE:
raise ValueError('LinSolver: Unknown solver type: %s'
% opts['linsolver'])

if (linsolver in ['dense', 'lapackdense']) and self.aux_data.jac:
if (linsolver in ['dense', 'lapackdense']) and self.aux_data.jac_viewer:
CVDlsSetDenseJacFn(cv_mem, _jac_viewer_dense)
elif (linsolver in ['dense', 'lapackdense']) and self.aux_data.jac:
CVDlsSetDenseJacFn(cv_mem, _jacdense)

#we test if jac don't give errors due to bad coding, as
Expand Down