diff --git a/skyfield/vectorlib.py b/skyfield/vectorlib.py index 57ee37a0b..3941a388a 100644 --- a/skyfield/vectorlib.py +++ b/skyfield/vectorlib.py @@ -1,13 +1,14 @@ """Vector functions and their composition.""" from jplephem.names import target_names as _jpl_code_name_dict -from numpy import max +from numpy import max, newaxis, expand_dims, broadcast_to, diagonal, squeeze from .constants import C_AUDAY from .descriptorlib import reify from .errors import DeprecationError -from .functions import length_of +from .functions import length_of, _reconcile from .positionlib import build_position from .timelib import Time +from .units import Distance, Velocity class VectorFunction(object): """Given a time, computes a corresponding position.""" @@ -215,8 +216,10 @@ def _at(self, t): p2, v2, another_gcrs_position, message = vf._at(t) if gcrs_position is None: # TODO: so bootleg; rework whole idea gcrs_position = another_gcrs_position - p += p2 - v += v2 + p, p2 = _reconcile(p, p2) + p = p + p2 + v, v2 = _reconcile(v, v2) + v = v + v2 return p, v, gcrs_position, message def _correct_for_light_travel_time(observer, target): @@ -237,10 +240,20 @@ def _correct_for_light_travel_time(observer, target): cposition = observer.position.au cvelocity = observer.velocity.au_per_d + cposition = expand_dims(cposition, 2) + cvelocity = expand_dims(cvelocity, 2) + tposition, tvelocity, gcrs_position, message = target._at(t) + tposition = expand_dims(tposition, 3) + tvelocity = expand_dims(tvelocity, 3) + distance = length_of(tposition - cposition) light_time0 = 0.0 + + whole = broadcast_to(whole[:, newaxis, newaxis], (len(t.tt), tposition.shape[2], cposition.shape[2])) + tdb_fraction = tdb_fraction[:, newaxis, newaxis] + for i in range(10): light_time = distance / C_AUDAY delta = light_time - light_time0 @@ -252,12 +265,12 @@ def _correct_for_light_travel_time(observer, target): # fraction, for adding to the whole and fraction of TDB. t2 = ts.tdb_jd(whole, tdb_fraction - light_time) - tposition, tvelocity, gcrs_position, message = target._at(t2) + tposition, tvelocity, gcrs_position, message = diagonal(target._at(t2), axis1=1, axis2=2) distance = length_of(tposition - cposition) light_time0 = light_time else: raise ValueError('light-travel time failed to converge') - return tposition - cposition, tvelocity - cvelocity, t, light_time + return squeeze(tposition - cposition), squeeze(tvelocity - cvelocity), t, light_time def _jpl_name(target): if not isinstance(target, int):