Skip to content

Commit 31ef01b

Browse files
committed
fixed est_update_covariance
1 parent 1c835ac commit 31ef01b

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/qinfer/derived_models.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -760,25 +760,24 @@ def simulate_experiment(self, modelparams, expparams, repeat=1):
760760

761761
def est_update_covariance(self, modelparams):
762762
"""
763-
Returns the covariance of the gaussion noise process for one
763+
Returns the covariance of the gaussian noise process for one
764764
unit step. In the case where the covariance is being learned,
765765
the expected covariance matrix is returned.
766766
767767
:param modelparams: Shape `(n_models, n_modelparams)` shape array
768768
of model parameters.
769769
"""
770770
if self._diagonal:
771-
scale = (self._fixed_scale if self._has_fixed_covariance
772-
else np.mean(modelparams[:, self._srw_idxs], axis=0))
773-
cov = np.diag(scale ** 2)
771+
cov = (self._fixed_scale ** 2 if self._has_fixed_covariance \
772+
else np.mean(modelparams[:, self._srw_idxs] ** 2, axis=0))
773+
cov = np.diag(cov)
774774
else:
775775
if self._has_fixed_covariance:
776-
chol = self._fixed_chol
776+
cov = np.dot(self._fixed_chol, self._fixed_chol.T)
777777
else:
778778
chol = np.zeros((modelparams.shape[0], self._n_rw, self._n_rw))
779779
chol[(np.s_[:],) + self._srw_tri_idxs] = modelparams[:, self._srw_idxs]
780-
chol = np.mean(chol, axis=0)
781-
cov = np.dot(chol, chol.T)
780+
cov = np.mean(np.einsum('ijk,ilk->ijl', chol, chol), axis=0)
782781
return cov
783782

784783
def update_timestep(self, modelparams, expparams):

0 commit comments

Comments
 (0)