From 40f3e96ac4634627881da63c0cbfc34ce40c412a Mon Sep 17 00:00:00 2001 From: Simon Walker Date: Sun, 7 Feb 2016 20:41:30 +0000 Subject: [PATCH 1/6] New api uses a namedtuple to handle the output result --- testing/test_lweiss.py | 3 ++- testing/test_python_api.py | 38 +++++++++++++++++++++++--------------- testing/test_rv.py | 4 ++-- ttvfast/__init__.py | 21 ++++++++++++++++++--- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/testing/test_lweiss.py b/testing/test_lweiss.py index 96a373a..be631b8 100644 --- a/testing/test_lweiss.py +++ b/testing/test_lweiss.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import ttvfast @@ -5,7 +6,7 @@ Based on a bug report supplied by Laren Weiss ''' - +@pytest.mark.skipif(True, reason='Out of date API') def test_application(args): setup = args Time, dt, Total = setup[1:4] diff --git a/testing/test_python_api.py b/testing/test_python_api.py index 0062686..8f4b8e9 100644 --- a/testing/test_python_api.py +++ b/testing/test_python_api.py @@ -2,26 +2,34 @@ import ttvfast - -def test_python_call(stellar_mass, planets, python_args): - Time, dt, Total = python_args - results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - - python_rows = zip(*results['positions']) - +def check_against_output_file(results): + ''' + Function to check the output of `ttvfast` with the example output file + ''' with open('testing/example_output.txt') as infile: - for i, (python_row, c_row) in enumerate( - zip(python_rows, infile)): + for i, c_row in enumerate(infile): c_row = c_row.strip().split() - vals = (int(c_row[0]), - int(c_row[1]), - float(c_row[2]), - float(c_row[3]), - float(c_row[4])) - assert np.allclose(vals, python_row) + expected = ( + int(c_row[0]), + int(c_row[1]), + float(c_row[2]), + float(c_row[3]), + float(c_row[4]), + ) + result = (results.planets[i], + results.epochs[i], + results.times[i], + results.rsky[i], + results.vsky[i]) + assert np.allclose(result, expected) assert i == 374 +def test_python_call(stellar_mass, planets, python_args): + Time, dt, Total = python_args + results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) + check_against_output_file(results) + def test_module_docstring_is_present(): assert 'Fast TTV computation' in ttvfast.__doc__ diff --git a/testing/test_rv.py b/testing/test_rv.py index 426b916..447e0e9 100644 --- a/testing/test_rv.py +++ b/testing/test_rv.py @@ -13,10 +13,10 @@ def test_rv_given(stellar_mass, planets, python_args): Time, dt, Total = python_args results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total, rv_times=rv_times) - assert np.allclose(results['rv'], expected) + assert np.allclose(results.rv, expected) def test_no_rv_given(stellar_mass, planets, python_args): Time, dt, Total = python_args results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - assert results['rv'] is None + assert results.rv is None diff --git a/ttvfast/__init__.py b/ttvfast/__init__.py index e4b7ecc..08d2afc 100644 --- a/ttvfast/__init__.py +++ b/ttvfast/__init__.py @@ -2,9 +2,17 @@ "Fast TTV computation" + +__all__ = ['ttvfast'] + + +from collections import namedtuple from ._ttvfast import _ttvfast as _ttvfast_fn from . import models +TTVFastResult = namedtuple('TTVFastResult', [ + 'planets', 'epochs', 'times', 'rsky', 'vsky', 'rv', +]) __all__ = ['ttvfast'] @@ -25,8 +33,15 @@ def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None): input_flag = 0 len_rv = len(rv_times) if rv_times is not None else 0 - positions, rv = _ttvfast_fn( - params, dt, time, total, n_plan, input_flag, len_rv, rv_times) - return {'positions': positions, 'rv': rv} + positions, rv = _ttvfast_fn(params, dt, time, total, n_plan, input_flag, len_rv, rv_times) + + return TTVFastResult( + planets=positions[0], + epochs=positions[1], + times=positions[2], + rsky=positions[3], + vsky=positions[4], + rv=rv + ) __all__ = ['ttvfast'] From f51b59a819787f5f206e8f6bd6480db92875a0ea Mon Sep 17 00:00:00 2001 From: Simon Walker Date: Sun, 7 Feb 2016 20:43:32 +0000 Subject: [PATCH 2/6] Return namedtuple of numpy arrays --- setup.py | 1 + ttvfast/__init__.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 6564c57..9a10d3e 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ author_email='s.r.walker101@googlemail.com', license='GPL', packages=['ttvfast', ], + install_requires=['numpy', ], ext_modules=[ttvfast, ], classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/ttvfast/__init__.py b/ttvfast/__init__.py index 08d2afc..723e019 100644 --- a/ttvfast/__init__.py +++ b/ttvfast/__init__.py @@ -7,6 +7,7 @@ from collections import namedtuple +import numpy as np from ._ttvfast import _ttvfast as _ttvfast_fn from . import models @@ -36,12 +37,12 @@ def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None): positions, rv = _ttvfast_fn(params, dt, time, total, n_plan, input_flag, len_rv, rv_times) return TTVFastResult( - planets=positions[0], - epochs=positions[1], - times=positions[2], - rsky=positions[3], - vsky=positions[4], - rv=rv + planets=np.array(positions[0]), + epochs=np.array(positions[1]), + times=np.array(positions[2]), + rsky=np.array(positions[3]), + vsky=np.array(positions[4]), + rv=np.array(rv) if rv else None, ) __all__ = ['ttvfast'] From 00993d8b9d18d651e00b64050882fc6de0d465c5 Mon Sep 17 00:00:00 2001 From: Simon Walker Date: Sun, 24 Jul 2016 11:14:22 +0100 Subject: [PATCH 3/6] Add pylint and flake8 to dev requirements --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index d5cb77f..f2a0f4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ pytest>=2.7.3 six==1.9.0 numpy==1.9.1 pytest-cov==2.2.1 +pylint==1.6.4 +flake8==2.6.2 From 8cebfb836156a64206e28a941f57eac7bb96df3f Mon Sep 17 00:00:00 2001 From: Simon Walker Date: Sun, 24 Jul 2016 11:16:23 +0100 Subject: [PATCH 4/6] Fix linting errors --- testing/test_lweiss.py | 1 + testing/test_python_api.py | 2 ++ ttvfast/__init__.py | 9 ++++----- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/testing/test_lweiss.py b/testing/test_lweiss.py index be631b8..db75b59 100644 --- a/testing/test_lweiss.py +++ b/testing/test_lweiss.py @@ -6,6 +6,7 @@ Based on a bug report supplied by Laren Weiss ''' + @pytest.mark.skipif(True, reason='Out of date API') def test_application(args): setup = args diff --git a/testing/test_python_api.py b/testing/test_python_api.py index 8f4b8e9..83af9a3 100644 --- a/testing/test_python_api.py +++ b/testing/test_python_api.py @@ -2,6 +2,7 @@ import ttvfast + def check_against_output_file(results): ''' Function to check the output of `ttvfast` with the example output file @@ -25,6 +26,7 @@ def check_against_output_file(results): assert i == 374 + def test_python_call(stellar_mass, planets, python_args): Time, dt, Total = python_args results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) diff --git a/ttvfast/__init__.py b/ttvfast/__init__.py index 723e019..3d45d22 100644 --- a/ttvfast/__init__.py +++ b/ttvfast/__init__.py @@ -2,15 +2,13 @@ "Fast TTV computation" - -__all__ = ['ttvfast'] - - from collections import namedtuple import numpy as np from ._ttvfast import _ttvfast as _ttvfast_fn from . import models +__all__ = ['ttvfast'] + TTVFastResult = namedtuple('TTVFastResult', [ 'planets', 'epochs', 'times', 'rsky', 'vsky', 'rv', ]) @@ -34,7 +32,8 @@ def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None): input_flag = 0 len_rv = len(rv_times) if rv_times is not None else 0 - positions, rv = _ttvfast_fn(params, dt, time, total, n_plan, input_flag, len_rv, rv_times) + positions, rv = _ttvfast_fn(params, dt, time, total, n_plan, + input_flag, len_rv, rv_times) return TTVFastResult( planets=np.array(positions[0]), From a04c9b2570a1f223eddd2ae51431ceef0d7b8302 Mon Sep 17 00:00:00 2001 From: Simon Walker Date: Sun, 24 Jul 2016 11:21:56 +0100 Subject: [PATCH 5/6] Testing lweiss results with new api --- testing/test_lweiss.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/testing/test_lweiss.py b/testing/test_lweiss.py index db75b59..29959b5 100644 --- a/testing/test_lweiss.py +++ b/testing/test_lweiss.py @@ -1,4 +1,3 @@ -import pytest import numpy as np import ttvfast @@ -7,7 +6,6 @@ ''' -@pytest.mark.skipif(True, reason='Out of date API') def test_application(args): setup = args Time, dt, Total = setup[1:4] @@ -26,8 +24,13 @@ def test_application(args): assert 0.9 < stellar_mass < 1.0 results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - python_rows = list(zip(*results['positions'])) + test_row = 22 expected = [1, 7, -8.828648752325788e+02, 6.363231859868642e-03, 4.321183741781629e-02] - assert np.allclose(python_rows[22], expected) + found = [ + results.planets[test_row], results.epochs[test_row], + results.times[test_row], results.rsky[test_row], + results.vsky[test_row] + ] + assert np.allclose(found, expected) From d5379d6f75d96dff5489786c39ece4c498af5616 Mon Sep 17 00:00:00 2001 From: Simon Walker Date: Sun, 24 Jul 2016 16:01:59 +0100 Subject: [PATCH 6/6] Add #row method to TTVFastResult --- testing/test_lweiss.py | 8 +------ testing/test_python_api.py | 47 ++++++++++++++++++++++++++++++++++++++ ttvfast/__init__.py | 27 ++++++++++++++++++++-- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/testing/test_lweiss.py b/testing/test_lweiss.py index 29959b5..f7a884c 100644 --- a/testing/test_lweiss.py +++ b/testing/test_lweiss.py @@ -25,12 +25,6 @@ def test_application(args): assert 0.9 < stellar_mass < 1.0 results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total) - test_row = 22 expected = [1, 7, -8.828648752325788e+02, 6.363231859868642e-03, 4.321183741781629e-02] - found = [ - results.planets[test_row], results.epochs[test_row], - results.times[test_row], results.rsky[test_row], - results.vsky[test_row] - ] - assert np.allclose(found, expected) + assert np.allclose(results.row(22), expected) diff --git a/testing/test_python_api.py b/testing/test_python_api.py index 83af9a3..fcd6f26 100644 --- a/testing/test_python_api.py +++ b/testing/test_python_api.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import ttvfast @@ -39,3 +40,49 @@ def test_module_docstring_is_present(): def test_ttvfast_docstring_is_present(): assert 'https://github.com/kdeck/TTVFast' in ttvfast.ttvfast.__doc__ + + +class TestTTVFastResult(object): + + @pytest.fixture(scope='module') + def result_with_rv(self): + return ttvfast.TTVFastResult( + planets=np.array([0, 1, 0, 1]), + epochs=np.random.uniform(0., 1., size=4), + times=np.random.uniform(0., 1., size=4), + rsky=np.random.uniform(-1., 1., size=4), + vsky=np.random.uniform(-1., 1., size=4), + rv=np.random.uniform(-1., 1., size=4), + ) + + @pytest.fixture(scope='module') + def result_without_rv(self): + return ttvfast.TTVFastResult( + planets=np.array([0, 1, 0, 1]), + epochs=np.random.uniform(0., 1., size=4), + times=np.random.uniform(0., 1., size=4), + rsky=np.random.uniform(-1., 1., size=4), + vsky=np.random.uniform(-1., 1., size=4), + rv=None, + ) + + def test_get_length(self, result_without_rv): + assert len(result_without_rv) == 4 + + def test_get_row_without_rv(self, result_without_rv): + keys = ['planets', 'epochs', 'times', 'rsky', 'vsky'] + for i in range(4): + expected = [getattr(result_without_rv, key)[i] for key in keys] + assert np.allclose(result_without_rv.row(i), expected) + + def test_get_row_with_rv(self, result_with_rv): + keys = ['planets', 'epochs', 'times', 'rsky', 'vsky', 'rv'] + for i in range(4): + expected = [getattr(result_with_rv, key)[i] for key in keys] + assert np.allclose(result_with_rv.row(i), expected) + + def test_get_invalid_row(self, result_without_rv): + with pytest.raises(IndexError) as exc_info: + result_without_rv.row(100) + + assert 'Index 100 out of bounds for array length 4' in str(exc_info) diff --git a/ttvfast/__init__.py b/ttvfast/__init__.py index 3d45d22..2db4ea9 100644 --- a/ttvfast/__init__.py +++ b/ttvfast/__init__.py @@ -9,11 +9,34 @@ __all__ = ['ttvfast'] -TTVFastResult = namedtuple('TTVFastResult', [ +TTVFastResultBase = namedtuple('TTVFastResultBase', [ 'planets', 'epochs', 'times', 'rsky', 'vsky', 'rv', ]) -__all__ = ['ttvfast'] + +class TTVFastResult(TTVFastResultBase): + def __len__(self): + '''Enables the `len` function to work''' + return self.times.size + + def row(self, index): + '''Return a single entry into the results array''' + if index >= len(self): + raise IndexError( + "Index {index} out of bounds for array length {length}".format( + index=index, length=len(self))) + + arr = [ + self.planets[index], + self.epochs[index], + self.times[index], + self.rsky[index], + self.vsky[index], + ] + if self.rv is not None: + arr.append(self.rv[index]) + + return arr def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None):