Skip to content

Commit b4c4924

Browse files
authored
Merge pull request #132 from ihincks/fix-test_model
Fix test_model
2 parents 4fd1e4a + 6c9d7e0 commit b4c4924

File tree

7 files changed

+63
-10
lines changed

7 files changed

+63
-10
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ install:
5353
# Before proceeding to install anything else, we debu
5454
# by making sure that the conda install didn't break matplotlib.
5555
- python -c "import matplotlib.pyplot as plt; print('Using MPL backend:'); print(plt.get_backend())"
56+
- pip install --upgrade pip
5657
- pip install -r requirements.txt
5758
# Before proceeding, we pause to list all installed conda and pip
5859
# packages for later debugging.

doc/source/apiref/test_models.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,16 @@ built on top of QInfer.
3939

4040
.. autoclass:: NDieModel
4141
:members:
42+
43+
Custom Models
44+
-------------
45+
46+
Writing custom models is standard practice for QInfer users.
47+
See :ref:`CustomModels`.
48+
49+
.. currentmodule:: qinfer.tests.base_test
50+
51+
:meth:`test_model` - Method to run suite of tests on a model instance
52+
---------------------------------------------------------------------
4253

54+
.. autofunction:: test_model

doc/source/guide/models.rst

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ True
319319
>>> L.shape == (1, 100, 9)
320320
True
321321

322+
.. _CustomModels:
323+
322324
Implementing Custom Simulators and Models
323325
-----------------------------------------
324326

@@ -414,6 +416,33 @@ True
414416
>>> D.shape == (2, 10000, 81)
415417
True
416418

419+
Finally, we mention a useful tool for doing a set of
420+
tests on the custom model, which make sure its pieces are working as
421+
expected. These tests look at things like data types and index dimensions of
422+
various functions. They also plug the outputs of some methods into the inputs
423+
of other methods, and so forth. Although they can't check the statistical
424+
soundness of your model, if they all pass, you can be pretty confident
425+
you won't run into weird indexing bugs in the future.
426+
427+
We just need to pass :func:`~qinfer.tests.test_model` an instance of the
428+
custom model, a prior that samples valid model parameters, and an array of valid
429+
``expparams``.
430+
431+
>>> from qinfer.tests import test_model
432+
>>> from qinfer import UniformDistribution
433+
>>> prior = UniformDistribution([[0,1],[0,1]])
434+
>>> test_model(mcm, prior, expparams)
435+
436+
.. code-block:: None
437+
:emphasize-lines: 1,2,3,4,5
438+
439+
.......
440+
----------------------------------------------------------------------
441+
Ran 7 tests in 0.013s
442+
443+
OK
444+
445+
417446
.. note::
418447

419448
Creating ``expparams`` as an empty array and filling it by field name is a
@@ -456,4 +485,3 @@ which is discussed in more detail in :ref:`perf_testing_guide`. Roughly,
456485
this model causes the likeihood functions calculated by its underlying model
457486
to be subject to random noise, so that the robustness of an inference algorithm
458487
against such noise can be tested.
459-

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def write_version(filename=VERSION_TARGET):
5555
packages=[
5656
'qinfer',
5757
'qinfer._lib',
58-
'qinfer.tomography'
58+
'qinfer.tomography',
59+
'qinfer.tests'
5960
],
6061
keywords=['quantum', 'Bayesian', 'estimation'],
6162
description=

src/qinfer/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from __future__ import absolute_import
22
from qinfer.tests import (
33
test_distributions, base_test, test_precession_model)
4+
from qinfer.tests.base_test import test_model

src/qinfer/tests/base_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import numpy as np
3838
from numpy.testing import assert_equal, assert_almost_equal
3939
import unittest
40-
from qinfer import Domain, FiniteOutcomeModel
40+
from qinfer import Domain, Model, Simulatable, FiniteOutcomeModel, DifferentiableModel
4141

4242
from contextlib import contextmanager
4343

@@ -56,11 +56,11 @@ def test_model(model, prior, expparams, stream=sys.stderr):
5656
"""
5757

5858
if isinstance(model, DifferentiableModel):
59-
test_class = TestConcreteDifferentiableModel
59+
test_class = ConcreteDifferentiableModelTest
6060
elif isinstance(model, Model):
61-
test_class = TestConcreteModel
61+
test_class = ConcreteModelTest
6262
elif isinstance(model, Simulatable):
63-
test_class = TestConcreteSimulatable
63+
test_class = ConcreteSimulatableTest
6464
else:
6565
raise ValueError("Given model has unrecognized type.")
6666

@@ -72,9 +72,9 @@ def instantiate_prior(self):
7272
def instantiate_expparams(self):
7373
return expparams
7474

75-
test = unittest.TestSuite((TestGivenModel, ))
75+
suite = unittest.TestLoader().loadTestsFromTestCase(TestGivenModel)
7676
runner = unittest.TextTestRunner(stream=stream)
77-
runner.run(test)
77+
runner.run(suite)
7878

7979
@contextmanager
8080
def assert_warns(category):

src/qinfer/tests/test_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/python
22
# -*- coding: utf-8 -*-
33
##
4-
# test_smc.py: Checks that utilities for unit testing
4+
# test_test.py: Checks that utilities for unit testing
55
# actually run the tests we expect.
66
##
77
# © 2014 Chris Ferrie (csferrie@gmail.com) and
@@ -37,7 +37,10 @@
3737
import numpy as np
3838
from numpy.testing import assert_equal, assert_almost_equal
3939

40-
from qinfer.tests.base_test import DerandomizedTestCase, MockModel, assert_warns
40+
from qinfer import UniformDistribution
41+
from qinfer.tests.base_test import (
42+
DerandomizedTestCase, MockModel, assert_warns, test_model
43+
)
4144

4245
## TESTS #####################################################################
4346

@@ -51,3 +54,10 @@ def test_assert_warns_ok(self):
5154
def test_assert_warns_nowarn(self):
5255
with assert_warns(RuntimeWarning):
5356
pass
57+
58+
def test_test_model_runs(self):
59+
model = MockModel()
60+
prior = UniformDistribution(np.array([[10,12],[2,3]]))
61+
eps = np.arange(10,20).astype(model.expparams_dtype)
62+
test_model(model, prior, eps)
63+

0 commit comments

Comments
 (0)