|
| 1 | +import numpy as np |
| 2 | +from scipy.linalg import solve |
| 3 | +import copy as cp |
| 4 | +from pipt.misc_tools import analysis_tools as at |
| 5 | + |
| 6 | +class margIS_update(): |
| 7 | + |
| 8 | + """ |
| 9 | + Placeholder for private margIS method |
| 10 | + """ |
| 11 | + def update(self): |
| 12 | + if self.iteration == 1: # method requires some initiallization |
| 13 | + self.aug_prior = cp.deepcopy(at.aug_state(self.prior_state, self.list_states)) |
| 14 | + self.mean_prior = self.aug_prior.mean(axis=1) |
| 15 | + self.X = (self.aug_prior - np.dot(np.resize(self.mean_prior, (len(self.mean_prior), 1)), |
| 16 | + np.ones((1, self.ne)))) |
| 17 | + self.W = np.eye(self.ne) |
| 18 | + self.current_w = np.zeros((self.ne,)) |
| 19 | + self.E = np.dot(self.real_obs_data, self.proj) |
| 20 | + |
| 21 | + M = len(self.real_obs_data) |
| 22 | + Ytmp = solve(self.W, self.proj) |
| 23 | + if len(self.scale_data.shape) == 1: |
| 24 | + Y = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1), np.ones((1, self.ne))) * \ |
| 25 | + np.dot(self.aug_pred_data, Ytmp) |
| 26 | + else: |
| 27 | + Y = solve(self.scale_data, np.dot(self.aug_pred_data, Ytmp)) |
| 28 | + |
| 29 | + pred_data_mean = np.mean(self.aug_pred_data, 1) |
| 30 | + delta_d = (self.obs_data_vector - pred_data_mean) |
| 31 | + |
| 32 | + if len(self.cov_data.shape) == 1: |
| 33 | + S = np.dot(delta_d, (self.cov_data**(-1)) * delta_d) |
| 34 | + Ratio = M / S |
| 35 | + grad_lklhd = np.dot(Y.T * Ratio, (self.cov_data**(-1)) * delta_d) |
| 36 | + grad_prior = (self.ne - 1) * self.current_w |
| 37 | + self.C_w = (np.dot(Ratio * Y.T, np.dot(np.diag(self.cov_data ** (-1)), Y)) + (self.ne - 1) * np.eye(self.ne)) |
| 38 | + else: |
| 39 | + S = np.dot(delta_d, solve(self.cov_data, delta_d)) |
| 40 | + Ratio = M / S |
| 41 | + grad_lklhd = np.dot(Y.T * Ratio, solve(self.cov_data, delta_d)) |
| 42 | + grad_prior = (self.ne - 1) * self.current_w |
| 43 | + self.C_w = (np.dot(Ratio * Y.T, solve(self.cov_data, Y)) + (self.ne - 1) * np.eye(self.ne)) |
| 44 | + |
| 45 | + self.sqrt_w_step = solve(self.C_w, grad_prior + grad_lklhd) |
0 commit comments