diff --git a/requirements.txt b/requirements.txt index d5cb77f..f2a0f4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ pytest>=2.7.3 six==1.9.0 numpy==1.9.1 pytest-cov==2.2.1 +pylint==1.6.4 +flake8==2.6.2 diff --git a/setup.py b/setup.py index 6564c57..9a10d3e 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ author_email='s.r.walker101@googlemail.com', license='GPL', packages=['ttvfast', ], + install_requires=['numpy', ], ext_modules=[ttvfast, ], classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/testing/test_lweiss.py b/testing/test_lweiss.py index 96a373a..f7a884c 100644 --- a/testing/test_lweiss.py +++ b/testing/test_lweiss.py @@ -24,8 +24,7 @@ def test_application(args): assert 0.9 < stellar_mass < 1.0 results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - python_rows = list(zip(*results['positions'])) expected = [1, 7, -8.828648752325788e+02, 6.363231859868642e-03, 4.321183741781629e-02] - assert np.allclose(python_rows[22], expected) + assert np.allclose(results.row(22), expected) diff --git a/testing/test_python_api.py b/testing/test_python_api.py index 0062686..fcd6f26 100644 --- a/testing/test_python_api.py +++ b/testing/test_python_api.py @@ -1,31 +1,88 @@ import numpy as np +import pytest import ttvfast -def test_python_call(stellar_mass, planets, python_args): - Time, dt, Total = python_args - results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - - python_rows = zip(*results['positions']) - +def check_against_output_file(results): + ''' + Function to check the output of `ttvfast` with the example output file + ''' with open('testing/example_output.txt') as infile: - for i, (python_row, c_row) in enumerate( - zip(python_rows, infile)): + for i, c_row in enumerate(infile): c_row = c_row.strip().split() - vals = (int(c_row[0]), - int(c_row[1]), - float(c_row[2]), - float(c_row[3]), - float(c_row[4])) - assert np.allclose(vals, python_row) + expected = ( + int(c_row[0]), + int(c_row[1]), + float(c_row[2]), + float(c_row[3]), + float(c_row[4]), + ) + result = (results.planets[i], + results.epochs[i], + results.times[i], + results.rsky[i], + results.vsky[i]) + assert np.allclose(result, expected) assert i == 374 +def test_python_call(stellar_mass, planets, python_args): + Time, dt, Total = python_args + results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) + check_against_output_file(results) + + def test_module_docstring_is_present(): assert 'Fast TTV computation' in ttvfast.__doc__ def test_ttvfast_docstring_is_present(): assert 'https://github.com/kdeck/TTVFast' in ttvfast.ttvfast.__doc__ + + +class TestTTVFastResult(object): + + @pytest.fixture(scope='module') + def result_with_rv(self): + return ttvfast.TTVFastResult( + planets=np.array([0, 1, 0, 1]), + epochs=np.random.uniform(0., 1., size=4), + times=np.random.uniform(0., 1., size=4), + rsky=np.random.uniform(-1., 1., size=4), + vsky=np.random.uniform(-1., 1., size=4), + rv=np.random.uniform(-1., 1., size=4), + ) + + @pytest.fixture(scope='module') + def result_without_rv(self): + return ttvfast.TTVFastResult( + planets=np.array([0, 1, 0, 1]), + epochs=np.random.uniform(0., 1., size=4), + times=np.random.uniform(0., 1., size=4), + rsky=np.random.uniform(-1., 1., size=4), + vsky=np.random.uniform(-1., 1., size=4), + rv=None, + ) + + def test_get_length(self, result_without_rv): + assert len(result_without_rv) == 4 + + def test_get_row_without_rv(self, result_without_rv): + keys = ['planets', 'epochs', 'times', 'rsky', 'vsky'] + for i in range(4): + expected = [getattr(result_without_rv, key)[i] for key in keys] + assert np.allclose(result_without_rv.row(i), expected) + + def test_get_row_with_rv(self, result_with_rv): + keys = ['planets', 'epochs', 'times', 'rsky', 'vsky', 'rv'] + for i in range(4): + expected = [getattr(result_with_rv, key)[i] for key in keys] + assert np.allclose(result_with_rv.row(i), expected) + + def test_get_invalid_row(self, result_without_rv): + with pytest.raises(IndexError) as exc_info: + result_without_rv.row(100) + + assert 'Index 100 out of bounds for array length 4' in str(exc_info) diff --git a/testing/test_rv.py b/testing/test_rv.py index 426b916..447e0e9 100644 --- a/testing/test_rv.py +++ b/testing/test_rv.py @@ -13,10 +13,10 @@ def test_rv_given(stellar_mass, planets, python_args): Time, dt, Total = python_args results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total, rv_times=rv_times) - assert np.allclose(results['rv'], expected) + assert np.allclose(results.rv, expected) def test_no_rv_given(stellar_mass, planets, python_args): Time, dt, Total = python_args results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - assert results['rv'] is None + assert results.rv is None diff --git a/ttvfast/__init__.py b/ttvfast/__init__.py index e4b7ecc..2db4ea9 100644 --- a/ttvfast/__init__.py +++ b/ttvfast/__init__.py @@ -2,12 +2,42 @@ "Fast TTV computation" +from collections import namedtuple +import numpy as np from ._ttvfast import _ttvfast as _ttvfast_fn from . import models - __all__ = ['ttvfast'] +TTVFastResultBase = namedtuple('TTVFastResultBase', [ + 'planets', 'epochs', 'times', 'rsky', 'vsky', 'rv', +]) + + +class TTVFastResult(TTVFastResultBase): + def __len__(self): + '''Enables the `len` function to work''' + return self.times.size + + def row(self, index): + '''Return a single entry into the results array''' + if index >= len(self): + raise IndexError( + "Index {index} out of bounds for array length {length}".format( + index=index, length=len(self))) + + arr = [ + self.planets[index], + self.epochs[index], + self.times[index], + self.rsky[index], + self.vsky[index], + ] + if self.rv is not None: + arr.append(self.rv[index]) + + return arr + def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None): ''' @@ -25,8 +55,16 @@ def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None): input_flag = 0 len_rv = len(rv_times) if rv_times is not None else 0 - positions, rv = _ttvfast_fn( - params, dt, time, total, n_plan, input_flag, len_rv, rv_times) - return {'positions': positions, 'rv': rv} + positions, rv = _ttvfast_fn(params, dt, time, total, n_plan, + input_flag, len_rv, rv_times) + + return TTVFastResult( + planets=np.array(positions[0]), + epochs=np.array(positions[1]), + times=np.array(positions[2]), + rsky=np.array(positions[3]), + vsky=np.array(positions[4]), + rv=np.array(rv) if rv else None, + ) __all__ = ['ttvfast']