Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ pytest>=2.7.3
six==1.9.0
numpy==1.9.1
pytest-cov==2.2.1
pylint==1.6.4
flake8==2.6.2
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
author_email='s.r.walker101@googlemail.com',
license='GPL',
packages=['ttvfast', ],
install_requires=['numpy', ],
ext_modules=[ttvfast, ],
classifiers=[
'Development Status :: 3 - Alpha',
Expand Down
3 changes: 1 addition & 2 deletions testing/test_lweiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def test_application(args):

assert 0.9 < stellar_mass < 1.0
results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total)
python_rows = list(zip(*results['positions']))

expected = [1, 7, -8.828648752325788e+02, 6.363231859868642e-03,
4.321183741781629e-02]
assert np.allclose(python_rows[22], expected)
assert np.allclose(results.row(22), expected)
85 changes: 71 additions & 14 deletions testing/test_python_api.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,88 @@
import numpy as np
import pytest

import ttvfast


def test_python_call(stellar_mass, planets, python_args):
Time, dt, Total = python_args
results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total)

python_rows = zip(*results['positions'])

def check_against_output_file(results):
'''
Function to check the output of `ttvfast` with the example output file
'''
with open('testing/example_output.txt') as infile:
for i, (python_row, c_row) in enumerate(
zip(python_rows, infile)):
for i, c_row in enumerate(infile):
c_row = c_row.strip().split()
vals = (int(c_row[0]),
int(c_row[1]),
float(c_row[2]),
float(c_row[3]),
float(c_row[4]))
assert np.allclose(vals, python_row)
expected = (
int(c_row[0]),
int(c_row[1]),
float(c_row[2]),
float(c_row[3]),
float(c_row[4]),
)
result = (results.planets[i],
results.epochs[i],
results.times[i],
results.rsky[i],
results.vsky[i])
assert np.allclose(result, expected)

assert i == 374


def test_python_call(stellar_mass, planets, python_args):
Time, dt, Total = python_args
results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total)
check_against_output_file(results)


def test_module_docstring_is_present():
assert 'Fast TTV computation' in ttvfast.__doc__


def test_ttvfast_docstring_is_present():
assert 'https://github.com/kdeck/TTVFast' in ttvfast.ttvfast.__doc__


class TestTTVFastResult(object):

@pytest.fixture(scope='module')
def result_with_rv(self):
return ttvfast.TTVFastResult(
planets=np.array([0, 1, 0, 1]),
epochs=np.random.uniform(0., 1., size=4),
times=np.random.uniform(0., 1., size=4),
rsky=np.random.uniform(-1., 1., size=4),
vsky=np.random.uniform(-1., 1., size=4),
rv=np.random.uniform(-1., 1., size=4),
)

@pytest.fixture(scope='module')
def result_without_rv(self):
return ttvfast.TTVFastResult(
planets=np.array([0, 1, 0, 1]),
epochs=np.random.uniform(0., 1., size=4),
times=np.random.uniform(0., 1., size=4),
rsky=np.random.uniform(-1., 1., size=4),
vsky=np.random.uniform(-1., 1., size=4),
rv=None,
)

def test_get_length(self, result_without_rv):
assert len(result_without_rv) == 4

def test_get_row_without_rv(self, result_without_rv):
keys = ['planets', 'epochs', 'times', 'rsky', 'vsky']
for i in range(4):
expected = [getattr(result_without_rv, key)[i] for key in keys]
assert np.allclose(result_without_rv.row(i), expected)

def test_get_row_with_rv(self, result_with_rv):
keys = ['planets', 'epochs', 'times', 'rsky', 'vsky', 'rv']
for i in range(4):
expected = [getattr(result_with_rv, key)[i] for key in keys]
assert np.allclose(result_with_rv.row(i), expected)

def test_get_invalid_row(self, result_without_rv):
with pytest.raises(IndexError) as exc_info:
result_without_rv.row(100)

assert 'Index 100 out of bounds for array length 4' in str(exc_info)
4 changes: 2 additions & 2 deletions testing/test_rv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def test_rv_given(stellar_mass, planets, python_args):

Time, dt, Total = python_args
results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total, rv_times=rv_times)
assert np.allclose(results['rv'], expected)
assert np.allclose(results.rv, expected)


def test_no_rv_given(stellar_mass, planets, python_args):
Time, dt, Total = python_args
results = ttvfast.ttvfast(planets, stellar_mass, Time, dt, Total)
assert results['rv'] is None
assert results.rv is None
46 changes: 42 additions & 4 deletions ttvfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,42 @@

"Fast TTV computation"

from collections import namedtuple
import numpy as np
from ._ttvfast import _ttvfast as _ttvfast_fn
from . import models


__all__ = ['ttvfast']

TTVFastResultBase = namedtuple('TTVFastResultBase', [
'planets', 'epochs', 'times', 'rsky', 'vsky', 'rv',
])


class TTVFastResult(TTVFastResultBase):
def __len__(self):
'''Enables the `len` function to work'''
return self.times.size

def row(self, index):
'''Return a single entry into the results array'''
if index >= len(self):
raise IndexError(
"Index {index} out of bounds for array length {length}".format(
index=index, length=len(self)))

arr = [
self.planets[index],
self.epochs[index],
self.times[index],
self.rsky[index],
self.vsky[index],
]
if self.rv is not None:
arr.append(self.rv[index])

return arr


def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None):
'''
Expand All @@ -25,8 +55,16 @@ def ttvfast(planets, stellar_mass, time, dt, total, rv_times=None):
input_flag = 0

len_rv = len(rv_times) if rv_times is not None else 0
positions, rv = _ttvfast_fn(
params, dt, time, total, n_plan, input_flag, len_rv, rv_times)
return {'positions': positions, 'rv': rv}
positions, rv = _ttvfast_fn(params, dt, time, total, n_plan,
input_flag, len_rv, rv_times)

return TTVFastResult(
planets=np.array(positions[0]),
epochs=np.array(positions[1]),
times=np.array(positions[2]),
rsky=np.array(positions[3]),
vsky=np.array(positions[4]),
rv=np.array(rv) if rv else None,
)

__all__ = ['ttvfast']