Skip to content

Commit bb9de21

Browse files
authored
Merge pull request #5 from artificial-life-lab/intergrate
harmonization of simulation solver method imports
2 parents c59fd0e + d021afd commit bb9de21

File tree

7 files changed

+94
-66
lines changed

7 files changed

+94
-66
lines changed

causal_inference/base/lv_simulator.py

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,12 @@
66
import logging
77
import datetime
88

9-
import h5py
10-
import matplotlib.pyplot as plt
11-
129
from causal_inference.config import RESULTS_DIR
1310
from causal_inference.utils.log_config import log_LV_params
1411
from causal_inference.base.ode_solver import ODE_solver
1512
from causal_inference.base.runge_kutta_solver import RungeKuttaSolver
16-
17-
def _save_population(prey_list, predator_list):
18-
filename = os.path.join(RESULTS_DIR, 'populations.h5')
19-
hf = h5py.File(filename, 'w')
20-
hf.create_dataset('prey_pop', data=prey_list)
21-
hf.create_dataset('pred_pop', data=predator_list)
22-
hf.close()
23-
24-
def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'):
25-
fig = plt.figure(figsize=(15, 5))
26-
ax = fig.add_subplot(2, 1, 1)
27-
PreyLine, = plt.plot(prey_list , color='g')
28-
PredatorsLine, = plt.plot(predator_list, color='r')
29-
ax.set_xscale('log')
30-
31-
plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
32-
plt.ylabel('Population')
33-
plt.xlabel('Time')
34-
if save:
35-
plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"),
36-
format='svg', transparent=False, bbox_inches='tight')
37-
else:
38-
plt.show()
39-
plt.close()
13+
from causal_inference.utils.writer import _save_population
14+
from causal_inference.utils.visualisations import plot_population_over_time
4015

4116
def get_solver(method):
4217
'''
@@ -50,15 +25,15 @@ def get_solver(method):
5025
raise AssertionError(f'{method} is not implemented!')
5126
return solver
5227

53-
def main(method):
28+
def main(method, results_dir):
5429
'''
5530
Main function that solves LV system.
5631
'''
5732
log_LV_params()
5833
solver = get_solver(method)
5934
prey_list, predator_list = solver._solve()
60-
_save_population(prey_list, predator_list)
61-
plot_population_over_time(prey_list, predator_list)
35+
_save_population(prey_list, predator_list, solver.time_stamps, results_dir)
36+
plot_population_over_time(prey_list, predator_list, solver.time_stamps, results_dir)
6237

6338
if __name__ == '__main__':
6439
PARSER = argparse.ArgumentParser()
@@ -68,12 +43,12 @@ def main(method):
6843
choices=['RK4', 'ODE'], default='RK4')
6944
ARGS = PARSER.parse_args()
7045

71-
RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))
46+
results_dir = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))
7247

73-
if not os.path.exists(RESULTS_DIR):
74-
os.makedirs(RESULTS_DIR)
48+
if not os.path.exists(results_dir):
49+
os.makedirs(results_dir)
7550

76-
LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file
51+
LOG_FILE = os.path.join(results_dir, f"{ARGS.logfile}.txt") # write logg to this file
7752
logging.basicConfig(
7853
level=logging.INFO,
7954
handlers=[
@@ -82,4 +57,4 @@ def main(method):
8257
]
8358
)
8459
solver = ARGS.solver
85-
main(solver)
60+
main(solver, results_dir)

causal_inference/base/lv_system.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
#!/usr/bin/python
22
# -*- coding: utf-8 -*-
33

4+
import numpy as np
45
from causal_inference.config import LV_PARAMS
56

67
class LotkaVolterra():
78
'''
8-
Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method.
9+
Base Lotka-Volterra Class that defines a predator-prey system.
910
'''
10-
def __init__(self):
11-
self.A = LV_PARAMS['A']
12-
self.B = LV_PARAMS['B']
13-
self.C = LV_PARAMS['C']
14-
self.D = LV_PARAMS['D']
15-
self.time = LV_PARAMS['INITIAL_TIME']
16-
self.step_size = LV_PARAMS['STEP_SIZE']
17-
self.max_iterations = LV_PARAMS['MAX_ITERATIONS']
11+
def __init__(self,
12+
A=LV_PARAMS['A'], B=LV_PARAMS['B'], C=LV_PARAMS['C'], D=LV_PARAMS['D'],
13+
prey_population=LV_PARAMS['INITIAL_PREY_POPULATION'],
14+
pred_population=LV_PARAMS['INITIAL_PREDATOR_POPULATION'],
15+
total_time=LV_PARAMS['TOTAL_TIME'], step_size=LV_PARAMS['STEP_SIZE'],
16+
max_iter=LV_PARAMS['MAX_ITERATIONS']):
17+
# Lotka-Volterra parameters
18+
self.A = A
19+
self.B = B
20+
self.C = C
21+
self.D = D
1822

19-
self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION']
20-
self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION']
23+
self.prey_population = prey_population # Initial prey population
24+
self.predator_population = pred_population # Initial predator population
25+
26+
self.init_time = 0 # initial time
27+
self.total_time = total_time # total time in units
28+
self.step_size = step_size # increment for each time step
29+
self.max_iterations = max_iter # tolerance parameter
30+
31+
self.time_stamps = np.arange(self.init_time, self.total_time, self.step_size)

causal_inference/base/ode_solver.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
from scipy import integrate
77

8-
from causal_inference.config import LV_PARAMS
98
from causal_inference.base.lv_system import LotkaVolterra
109

1110
class ODE_solver(LotkaVolterra):
@@ -14,20 +13,28 @@ class ODE_solver(LotkaVolterra):
1413
'''
1514
def __init__(self):
1615
super().__init__()
17-
logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver')
16+
logging.info('Simulating Lotka-Volterra predator-prey dynamics with odeint solver')
1817

1918
@staticmethod
20-
def LV_derivative(X, t, alpha, beta, delta, gamma):
21-
x, y = X
22-
dotx = x * (alpha - beta * y)
23-
doty = y * (-delta + gamma * x)
19+
def LV_derivative(t, Z, A, B, C, D):
20+
'''
21+
Returns the rate of change of predator and prey population
22+
'''
23+
x, y = Z
24+
dotx = x * (A - B * y)
25+
doty = y * (-C + D * x)
2426
return np.array([dotx, doty])
2527

2628
def _solve(self):
27-
logging.info('Computing population over time...')
28-
t = np.arange(0.,self.max_iterations, self.step_size)
29-
X0 = [self.prey_population, self.predator_population]
30-
res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D))
31-
prey_list, predator_list = res.T
29+
'''
30+
ODE solver that returns the predator and prey populations at each time step in time series.
31+
'''
32+
logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')
33+
34+
INIT_POP = [self.prey_population, self.predator_population]
35+
sol = integrate.solve_ivp(self.LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)
36+
prey_list, predator_list = sol.sol(self.time_stamps)
37+
3238
logging.info('done!')
33-
return prey_list, predator_list
39+
40+
return prey_list, predator_list

causal_inference/base/runge_kutta_solver.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/python
22
# -*- coding: utf-8 -*-
33
import logging
4-
from math import ceil
54
from causal_inference.base.lv_system import LotkaVolterra
65

76
class RungeKuttaSolver(LotkaVolterra):
@@ -11,7 +10,6 @@ class RungeKuttaSolver(LotkaVolterra):
1110
def __init__(self):
1211
super().__init__()
1312
logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method')
14-
self.time_stamp = [self.time]
1513
self.prey_list = [self.prey_population]
1614
self.predator_list = [self.predator_population]
1715

@@ -22,8 +20,6 @@ def compute_predator_rate(self, current_prey, current_predators):
2220
return - self.C * current_predators + self.D * current_prey * current_predators
2321

2422
def runge_kutta_update(self, current_prey, current_predators):
25-
self.time = self.time + self.step_size
26-
self.time_stamp.append(self.time)
2723

2824
k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators)
2925
k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators)
@@ -46,11 +42,17 @@ def runge_kutta_update(self, current_prey, current_predators):
4642
return new_prey_population, new_predator_population
4743

4844
def _solve(self):
45+
'''
46+
Runge-Kutta solver that returns the predator and prey populations at each time step in time series.
47+
'''
48+
#initial population
4949
current_prey, current_predators = self.prey_population, self.predator_population
50-
logging.info('Computing population over time...')
51-
for gen_idx in range(ceil(self.max_iterations/self.step_size)):
50+
51+
logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')
52+
53+
for step_idx in self.time_stamps[1:]:
5254
current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators)
53-
msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
55+
msg= f'Step: {step_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
5456
logging.info(msg)
5557
print('Done!')
5658
return self.prey_list, self.predator_list

causal_inference/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
'C' : 3.0,
1313
'D' : 5.0,
1414
'STEP_SIZE' : 0.01,
15-
'INITIAL_TIME' : 0,
15+
'TOTAL_TIME' : 20,
1616
'INITIAL_PREY_POPULATION' : 60,
1717
'INITIAL_PREDATOR_POPULATION' : 25,
1818
'MAX_ITERATIONS' : 200
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
from os.path import join
4+
import matplotlib.pyplot as plt
5+
6+
def plot_population_over_time(prey_list, predator_list, time_stamps, results_dir, save=True, filename='predator_prey'):
7+
fig = plt.figure(figsize=(15, 5))
8+
ax = fig.add_subplot(2, 1, 1)
9+
PreyLine, = plt.plot(time_stamps, prey_list, color='g')
10+
PredatorsLine, = plt.plot(time_stamps, predator_list, color='r')
11+
ax.set_xscale('log')
12+
13+
plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
14+
plt.ylabel('Population')
15+
plt.xlabel('Time')
16+
if save:
17+
plt.savefig(join(results_dir, f"{filename}.svg"),
18+
format='svg', transparent=False, bbox_inches='tight')
19+
else:
20+
plt.show()
21+
plt.close()

causal_inference/utils/writer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
from os.path import join
4+
import h5py
5+
6+
def _save_population(prey_list, predator_list, time_stamps, results_dir):
7+
filename = join(results_dir, 'populations.h5')
8+
hf = h5py.File(filename, 'w')
9+
hf.create_dataset('time_stamp', data=time_stamps)
10+
hf.create_dataset('prey_pop', data=prey_list)
11+
hf.create_dataset('pred_pop', data=predator_list)
12+
hf.close()

0 commit comments

Comments
 (0)