Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
"Environment :: Console",
]
dependencies = [
"jitcdde",
"jitcdde>1.8.1",
"numpy",
"chspy"
]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
jitcdde
git+https://github.com/neurophysik/jitcdde.git@master
numpy
chspy
pytest
Expand Down
276 changes: 183 additions & 93 deletions src/msrDynamics/_msrDynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import matplotlib.pyplot as plt
from tqdm import tqdm
from symengine import Mul
import sys

MAX_INT = sys.maxsize

class System:
"""
Expand Down Expand Up @@ -174,7 +177,7 @@ def set_custom_past(self, past: chspy.CubicHermiteSpline, t_truncate=None):
past.truncate(t_truncate)
self.custom_past = past

def finalize(self, max_delay):
def finalize(self):
"""
Instantiate and store JiTCDDE integrator.

Expand All @@ -190,9 +193,9 @@ def finalize(self, max_delay):

# input uses different integrator object
if self.input:
DDE = jitcdde_input(self.dydt, self.input, max_delay = max_delay)
DDE = jitcdde_input(self.dydt, self.input, max_delay = self.max_delay)
else:
DDE = jitcdde(self.dydt, max_delay = max_delay)
DDE = jitcdde(self.dydt, max_delay = self.max_delay)

# populate callback functions
if self.pid_loops:
Expand Down Expand Up @@ -254,6 +257,26 @@ def get_node_by_index(self, idx):
return n
raise ValueError(f'Node with index {idx} not found')

def _prepare_integrator(self,
abs_tol=1e-10,
rel_tol=1e-05,
min_step = 1e-10,
max_step = 10.0,
):

if len(self.nodes) == 0:
raise ValueError('No nodes have been added to the system')

# clear data
for n in self.nodes.values():
if (n.y_out.any()):
n.y_out = []

# set integrator
print("finalizing integrator...")
self.finalize()
self.integrator.set_integration_parameters(atol=abs_tol, rtol=rel_tol, min_step = min_step, max_step = max_step)

def solve(self,
T,
max_delay=1e10,
Expand Down Expand Up @@ -282,97 +305,162 @@ def solve(self,
Returns:
np.ndarray: Solution matrix.
"""
if len(self.nodes) == 0:
raise ValueError('No nodes have been added to the system')

# clear data
for n in self.nodes.values():
if (n.y_out.any()):
n.y_out = []

# set integrator
print("finalizing integrator...")
self.max_delay = max_delay
self.finalize(max_delay = max_delay)
self.integrator.set_integration_parameters(atol=abs_tol, rtol=rel_tol, min_step = min_step, max_step = max_step)
# solution
y = []
self._prepare_integrator(abs_tol, rel_tol, min_step, max_step)

print("integrating...")
# solve
if self.trip_conditions:
y = self._solve_with_trip_conditions(T, md_step)
else:
y = self._solve_default(T, md_step)

# integrate with trip conditions
for t_x in T:
# extract state and derivs for trip check
if (t_x < max_delay):
y.append(np.array(self.integrator.integrate_blindly(t_x, step = md_step)))
else:
y.append(np.array(self.integrator.integrate(t_x)))
idxs = [c.idx for c in self.trip_conditions]
states = y[-1][idxs]
# populate node objects with solutions, off by default, as it can cause
# memory blowup/leakage when running many models
if populate_nodes:
self._populate_nodes(y)

# derivative is only estimated after the first step
if len(y) > 1:
derivs = ((y[-1]-y[-2])/(T[1]-T[0]))[idxs]
else:
derivs = [0.0]*len(states)

if 'state' in self.trip_info:
self.trip_info['state'].extend([chspy.Anchor(t_x, states, derivs)])
return np.array(y)

def _solve_default(self, times, md_step):
y = []
with tqdm(total=len(times), desc="Integration progress") as pbar:
for t_x in times[times<=self.max_delay]:
y.append(self.integrator.integrate_blindly(t_x, step = md_step))
pbar.update(1)
for t_x in times[times>self.max_delay]:
y.append(self.integrator.integrate(t_x))
pbar.update(1)
return np.array(y)

def _solve_with_trip_conditions(self, times, md_step):
y = []
# integrate with trip conditions
for t_idx, t_x in enumerate(times):
# extract state and derivs for trip check
if (t_x <= self.max_delay):
y.append(np.array(self.integrator.integrate_blindly(t_x, step = md_step)))
else:
y.append(np.array(self.integrator.integrate(t_x)))
idxs = [c.idx for c in self.trip_conditions]
states = y[-1][idxs]

# derivative is only estimated after the first step
if len(y) > 1:
derivs = ((y[-1]-y[-2])/(times[t_idx]-times[t_idx-1]))[idxs]
else:
derivs = [0.0]*len(states)

if 'state' in self.trip_info:
self.trip_info['state'].extend([chspy.Anchor(t_x, states, derivs)])
else:
self.trip_info['state'] = chspy.CubicHermiteSpline(n=len(self.trip_conditions),
anchors=[chspy.Anchor(t_x, states, derivs)])

# check if system has tripped
tripped = self._check_trip(t_x, states, derivs)
if tripped:
# get trip condition object
trip_obj = self.trip_conditions[tripped[0]]
print(f'idx {tripped[0]} tripped after integration to t = {t_x:3f} with a value of {tripped[1]}')

# store trip info
self.trip_info['tripped'] = True
self.trip_info['idx'] = trip_obj.idx
self.trip_info['limit'] = tripped[1]
self.trip_info['type'] = trip_obj.trip_type

# get system spline
print('getting state...')
state = self.integrator.get_state()

# calculate exact trip time using splines
print('computing trip time within interval...')
trip_sol = []
start = trip_obj.check_after if trip_obj.check_after is not None else state[0].time
solve_diff = True if self.trip_info['type'] == 'diff' else False
trip_sol = state.solve(self.trip_info['idx'],
self.trip_info['limit'],
beginning=start,
solve_derivative = solve_diff)
if trip_obj.delay:
self.trip_info['time'] = trip_sol[0][0] + trip_obj.delay
else:
self.trip_info['state'] = chspy.CubicHermiteSpline(n=len(self.trip_conditions),
anchors=[chspy.Anchor(t_x, states, derivs)])

# check if system has tripped
tripped = self._check_trip(t_x, states, derivs)
if tripped:
# get trip condition object
trip_obj = self.trip_conditions[tripped[0]]
print(f'idx {tripped[0]} tripped after integration to t = {t_x:3f} with a value of {tripped[1]}')

# store trip info
self.trip_info['tripped'] = True
self.trip_info['idx'] = trip_obj.idx
self.trip_info['limit'] = tripped[1]
self.trip_info['type'] = trip_obj.trip_type

# get system spline
print('getting state...')
state = self.integrator.get_state()

# calculate exact trip time using splines
print('computing trip time within interval...')
trip_sol = []
start = trip_obj.check_after if trip_obj.check_after is not None else state[0].time
solve_diff = True if self.trip_info['type'] == 'diff' else False
trip_sol = state.solve(self.trip_info['idx'],
self.trip_info['limit'],
beginning=start,
solve_derivative = solve_diff)
if trip_obj.delay:
self.trip_info['time'] = trip_sol[0][0] + trip_obj.delay
else:
self.trip_info['time'] = trip_sol[0][0]
print(f"tripped at t = {self.trip_info['time']:.3f}")
print(f"state idx: {tripped[0]}")
print(f"limit: {tripped[1]}")
break
else:
with tqdm(total=len(T), desc="Integration progress") as pbar:
for t_x in T[T<=max_delay]:
y.append(self.integrator.integrate_blindly(t_x, step = md_step))
pbar.update(1)
for t_x in T[T>max_delay]:
y.append(self.integrator.integrate(t_x))
pbar.update(1)

# populate node objects with solutions
self.trip_info['time'] = trip_sol[0][0]
print(f"tripped at t = {self.trip_info['time']:.3f}")
print(f"state idx: {tripped[0]}")
print(f"limit: {tripped[1]}")
break
return np.array(y)

def equilibrium_search(self,
dT,
max_delay=1e10,
populate_nodes=False,
abs_tol=1e-10,
rel_tol=1e-05,
min_step = 1e-10,
max_step = 10.0,
md_step = 1e-3,
abs_tol_eq = 1e-6,
rel_tol_eq = 1e-4,
max_iter = MAX_INT,
norm = None,
show_conv_metrics = False
):
"""
Solves until equilibrium condition reached
"""
self.max_delay = max_delay
self._prepare_integrator(abs_tol, rel_tol, min_step, max_step)

if self.trip_conditions:
raise ValueError('equilibrium_search not compatible with trip conditions')

T = []
y = []
y0 = np.array([self.nodes[n].y0 for n in self.nodes])

diff = float('inf')
tol = abs_tol_eq + rel_tol_eq*np.linalg.norm(y0, ord = norm)
iters = 0
while (diff >= tol) and (iters < max_iter):
# find time
if len(T) == 0:
t_x = dT
else:
t_x = T[-1] + dT
T.append(t_x)

# calculate state
if (t_x <= self.max_delay):
y.append(self.integrator.integrate_blindly(t_x, step = md_step))
else:
y.append(self.integrator.integrate(t_x))

# update error & tolerance
if len(y) == 1:
diff = np.linalg.norm(y[-1]-y0, ord = norm)
else:
diff = np.linalg.norm(y[-1]-y[-2], ord = norm)
tol = abs_tol_eq + rel_tol_eq*np.linalg.norm(y[-1], ord = norm)
iters += 1

if show_conv_metrics:
print(f"converged at t = {T[-1]} after {iters} iterations at tol = {tol}")
print(f"||y_k - y_{{k-1}}||_2 = {diff}")

# populate node objects with solutions, off by default, as it can cause
# memory blowup/leakage when running many models
if populate_nodes:
print('populating nodes objects solution vectors...')
for s in enumerate(self.nodes.values()):
s[1].y_out = np.array([state[s[0]] for state in y])
self._populate_nodes(y)

return np.array(y)
return T, np.array(y)

def _populate_nodes(self, sol):
print('populating nodes objects solution vectors...')
for s in enumerate(self.nodes.values()):
s[1].y_out = np.array([state[s[0]] for state in sol])

def plot_input(self, index, fac = 1.0):
"""
Expand Down Expand Up @@ -459,14 +547,16 @@ def dydt(self):

@dydt.setter
def dydt(self, custom_dydt):
if isinstance(custom_dydt, Mul):
print(
""" Warning: You are setting this node's dynamics equal to that of
another node. If the other node's dynamics are updated, it will
not be propogated to this node. If you wish for updates to be
carried to this node, use Node.set_dydt_node() instead.
"""
)
# TODO: This error message shows anytime custom dydt is set, not just
# when setting equal to another node. Rethink check.
# if isinstance(custom_dydt, Mul):
# print(
# """ Warning: You are setting this node's dynamics equal to that of
# another node. If the other node's dynamics are updated, it will
# not be propogated to this node. If you wish for updates to be
# carried to this node, use Node.set_dydt_node() instead.
# """
# )
self._dydt = custom_dydt

def set_dTdt_advective(self, source):
Expand Down