Skip to content

Commit 1c835ac

Browse files
committed
made random_walk_idxs compatible with np.slice
1 parent c20232c commit 1c835ac

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

src/qinfer/derived_models.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -620,8 +620,10 @@ class GaussianRandomWalkModel(DerivedModel):
620620
621621
:param Model underlying_model: Model representing the likelihood with no
622622
random walk added.
623-
:param random_walk_idxs: A list of model parameter indeces to add the
624-
random walk to.
623+
:param random_walk_idxs: A list or ``np.slice`` of
624+
``underlying_model`` model parameter indeces to add the random walk to.
625+
Indeces larger than ``underlying_model.n_modelparams`` should not
626+
be touched.
625627
:param fixed_covariance: An ``np.ndarray`` specifying the fixed covariance
626628
matrix (or diagonal thereof if ``diagonal`` is ``True``) of the
627629
gaussian distribution. If set to ``None`` (default), this matrix is
@@ -649,14 +651,18 @@ def __init__(
649651
):
650652

651653
self._diagonal = diagonal
652-
self._rw_idxs = np.arange(underlying_model.n_modelparams).astype(np.int) \
654+
self._rw_idxs = np.s_[:underlying_model.n_modelparams] \
653655
if random_walk_idxs == 'all' else random_walk_idxs
654656

657+
explicit_idxs = np.arange(underlying_model.n_modelparams)[self._rw_idxs]
658+
if explicit_idxs.size == 0:
659+
raise IndexError('At least one model parameter must take a random walk.')
660+
655661
self._rw_names = [
656662
underlying_model.modelparam_names[idx]
657-
for idx in self._rw_idxs
663+
for idx in explicit_idxs
658664
]
659-
self._n_rw = len(self._rw_idxs)
665+
self._n_rw = len(explicit_idxs)
660666

661667
self._srw_names = []
662668
if fixed_covariance is None:
@@ -700,6 +706,9 @@ def __init__(
700706
)
701707

702708
super(GaussianRandomWalkModel, self).__init__(underlying_model)
709+
710+
if np.max(np.arange(self.n_modelparams)[self._rw_idxs]) > np.max(explicit_idxs):
711+
raise IndexError('random_walk_idxs out of bounds; must index (a subset of ) underlying_model modelparams.')
703712

704713
if scale_mult is None:
705714
self._scale_mult_fcn = (lambda expparams: 1)

src/qinfer/tests/test_concrete_models.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def instantiate_model(self):
407407
m,
408408
fixed_covariance = cov,
409409
diagonal = False,
410-
random_walk_idxs = [1,2,4],
410+
random_walk_idxs = np.s_[:6:2],
411411
model_transformation = (from_simplex, to_simplex),
412412
scale_mult = 'n_meas'
413413
)
@@ -480,3 +480,23 @@ def test_est_update_covariance(self):
480480
cov = self.model.est_update_covariance(self.modelparams)
481481
eigs, v = np.linalg.eig(cov)
482482
assert(np.greater_equal(eigs, -1e-10).all())
483+
484+
class TestGaussianRandomWalkModel5(DerandomizedTestCase):
485+
"""
486+
Tests miscillaneous properties of GaussianRandomWalkModel.
487+
"""
488+
489+
def test_indexing(self):
490+
model = lambda slice: GaussianRandomWalkModel(
491+
MultinomialModel(NDieModel(n=6)),
492+
random_walk_idxs = slice
493+
)
494+
495+
assert(model('all').n_modelparams == 12)
496+
assert(model(np.s_[:6]).n_modelparams == 12)
497+
assert(model(np.s_[:6:2]).n_modelparams == 9)
498+
assert(model([2,3,4]).n_modelparams == 9)
499+
500+
self.assertRaises(IndexError, model, np.s_[:7])
501+
self.assertRaises(IndexError, model, np.s_[6:])
502+
self.assertRaises(IndexError, model, [1,2,8])

0 commit comments

Comments
 (0)