diff --git a/.continuous-integration/travis/setup_dependencies_common.sh b/.continuous-integration/travis/setup_dependencies_common.sh index b650d8ff2e17..946bd13fa0fe 100755 --- a/.continuous-integration/travis/setup_dependencies_common.sh +++ b/.continuous-integration/travis/setup_dependencies_common.sh @@ -41,12 +41,13 @@ then fi # DOCUMENTATION DEPENDENCIES -# build_sphinx needs sphinx and matplotlib (for plot_directive). Note that -# this matplotlib will *not* work with py 3.x, but our sphinx build is +# build_sphinx needs sphinx as well as matplotlib and wcsaxes (for plot_directive). +# Note that this matplotlib will *not* work with py 3.x, but our sphinx build is # currently 2.7, so that's fine if [[ $SETUP_CMD == build_sphinx* ]] then $CONDA_INSTALL Sphinx=1.2.2 Pygments matplotlib + pip install wcsaxes fi # COVERAGE DEPENDENCIES diff --git a/CHANGES.rst b/CHANGES.rst index 61006338f63f..8c909af83e46 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -60,6 +60,8 @@ New Features - Add a function ``get_read_trace()`` that returns a traceback of the attempted read formats for the last call to ``astropy.io.ascii.read``. [#3688] + - Supports LZMA decompression via ``get_readable_fileobj`` [#3667] + - ``astropy.io.fits`` - Support reading and writing from bzip2 compressed files. i.e. ``.fits.bz2`` @@ -118,6 +120,9 @@ New Features - Added ``cenfunc``, ``stdfunc``, and ``axis`` keywords to ``sigma_clipped_stats``. [#3792] + - ``sigma_clip`` automatically masks invalid input values (NaNs, Infs) before + performing the clipping [#4051] + - Added the ``histogram`` routine, which is similar to ``np.histogram`` but includes several additional options for automatic determination of optimal histogram bins. Associated helper routines include ``bayesian_blocks``, @@ -191,11 +196,21 @@ New Features - ``astropy.utils`` + - Added new ``OrderedDescriptor`` and ``OrderedDescriptorContainer`` utility + classes that make it easier to implement classes with declarative APIs, + wherein class-level attributes have an inherit "ordering" to them that is + specified by the order in which those attributes are defined in the class + declaration (by defining them using special descriptors that have + ``OrderedDescriptor`` as a base class). See the API documentation for + these classes for more details. Coordinate frames and models now use this + interface. [#3679] + - Added function ``dtype_info_name`` to the ``data_info`` module to provide the name of a ``dtype`` for human-readable informational purposes. [#3868] - Added ``classproperty`` decorator--this is to ``property`` as ``classmethod`` is to normal instance methods. [#3982] + - ``iers.open`` now handles network URLs, as well as local paths. [#3850] - The ``astropy.utils.wraps`` decorator now takes an optional ``exclude_args`` argument not shared by the standard library ``wraps`` @@ -205,6 +220,14 @@ New Features of the wrapper function. This is particularly useful when wrapping an instance method as a function (to exclude the ``self`` argument). [#4017] + - ``get_readable_fileobj`` can automatically decompress LZMA ('.xz') + files using the ``lzma`` module of Python 3.3+ or, when available, the + ``backports.lzma`` package on earlier versions. [#3667] + + - The ``resolve_name`` utility now accepts any number of additional + positional arguments that are automatically dotted together with the + first ``name`` argument. [#4083] + - ``astropy.visualization`` - Added the ``hist`` function, which is similar to ``plt.hist`` but @@ -440,13 +463,6 @@ Bug fixes - ``astropy.units`` - - Added frequency-equivalency check when declaring doppler equivalencies - [#3728] - - - Define ``floor_divide`` (``//``) for ``Quantity`` to be consistent - ``divmod``, such that it only works where the quotient is dimensionless. - This guarantees that ``(q1 // q2) * q2 + (q1 % q2) == q1``. [#3817] - - ``astropy.utils`` - ``astropy.visualization`` @@ -463,6 +479,8 @@ Other Changes and Additions - The version of ``PLY`` that ships with astropy has been updated to 3.6. +- WCSAxes is now required for doc builds. [#4074] + 1.0.5 (unreleased) ------------------ @@ -547,6 +565,9 @@ API Changes - ``astropy.utils`` + - ``console`` was updated to support IPython 4.x and Jupyter 1.x. + [#4078] + - ``astropy.vo`` - ``astropy.wcs`` @@ -568,6 +589,9 @@ Bug Fixes - ``astropy.io.fits`` + - Fix bug when extending one header (without comments) with another + (with comments). [#3967] + - ``astropy.io.misc`` - ``astropy.io.registry`` @@ -576,6 +600,8 @@ Bug Fixes - ``astropy.modeling`` + - Cleaned up ``repr`` of models that have no parameters. [#4076] + - ``astropy.nddata`` - ``astropy.stats`` @@ -590,6 +616,10 @@ Bug Fixes - ``astropy.utils`` + - ``resolve_name`` no longer causes ``sys.modules`` to be cluttered with + additional copies of modules under a package imported like + ``resolve_name('numpy')``. [#4084] + - ``astropy.vo`` - ``astropy.wcs`` @@ -600,7 +630,7 @@ Other Changes and Additions - Nothing changed yet. -1.0.4 (unreleased) +1.0.4 (2015-08-11) ------------------ New Features @@ -699,6 +729,13 @@ Bug Fixes - ``astropy.units`` + - Added frequency-equivalency check when declaring doppler equivalencies + [#3728] + + - Define ``floor_divide`` (``//``) for ``Quantity`` to be consistent + ``divmod``, such that it only works where the quotient is dimensionless. + This guarantees that ``(q1 // q2) * q2 + (q1 % q2) == q1``. [#3817] + - Fixed the documentation of supported units to correctly report support for SI prefixes. Previously the table of supported units incorrectly showed several derived unit as not supporting prefixes, when in fact they do. diff --git a/astropy/_erfa/core.c.templ b/astropy/_erfa/core.c.templ index 1d18ff2b2589..e7768dfa49ae 100644 --- a/astropy/_erfa/core.c.templ +++ b/astropy/_erfa/core.c.templ @@ -38,7 +38,10 @@ typedef struct { static PyObject *Py_{{ func.pyname }}(PyObject *self, PyObject *args, PyObject *kwds) { - {%- for arg in func.args_by_inout('in|inout|out|ret|stat') %} + {%- for arg in func.args_by_inout('in|inout|out') %} + {{ arg.ctype }} (*_{{ arg.name }}){{ arg.cshape }}; + {%- endfor %} + {%- for arg in func.args_by_inout('ret|stat') %} {{ arg.ctype_ptr }} _{{ arg.name }}; {%- endfor %} {%- if func.args_by_inout('stat')|length > 0 %} @@ -52,10 +55,10 @@ static PyObject *Py_{{ func.pyname }}(PyObject *self, PyObject *args, PyObject * do { {%- for arg in func.args_by_inout('in|inout|out') %} - _{{ arg.name }} = (({{ arg.ctype }} *)(dataptrarray[{{ func.args.index(arg) }}])){%- if arg.ctype_ptr[-1] != '*' %}[0]{%- endif %}; + _{{ arg.name }} = (({{ arg.ctype }} (*){{ arg.cshape }})(dataptrarray[{{ func.args.index(arg) }}])); {%- endfor %} - {{ func.args_by_inout('ret|stat')|map(attribute='name')|surround('_', ' = ')|join }}{{func.name}}({{ func.args_by_inout('in|inout|out')|map(attribute='name')|prefix('_')|join(', ') }}); + {{ func.args_by_inout('ret|stat')|map(attribute='name')|surround('_', ' = ')|join }}{{func.name}}({{ func.args_by_inout('in|inout|out')|map(attribute='name_for_call')|join(', ') }}); {%- for arg in func.args_by_inout('ret|stat') %} *(({{ arg.ctype_ptr }} *)(dataptrarray[{{ func.args.index(arg) }}])) = _{{ arg.name }}; diff --git a/astropy/_erfa/erfa_generator.py b/astropy/_erfa/erfa_generator.py index 84110afa7a17..ddfaaee13ccb 100644 --- a/astropy/_erfa/erfa_generator.py +++ b/astropy/_erfa/erfa_generator.py @@ -15,6 +15,7 @@ import re import os.path +from astropy.utils.compat.odict import OrderedDict ctype_to_dtype = {'double' : "numpy.double", @@ -181,6 +182,17 @@ def dtype(self): def ndim(self): return len(self.shape) + @property + def cshape(self): + return ''.join(['[{0}]'.format(s) for s in self.shape]) + + @property + def name_for_call(self): + if self.is_ptr: + return '_'+self.name + else: + return '*_'+self.name + def __repr__(self): return "Argument('{0}', name='{1}', ctype='{2}', inout_state='{3}')".format(self.definition, self.name, self.ctype, self.inout_state) @@ -378,7 +390,7 @@ def surround(a_list, pre, post): with open(erfahfn, "r") as f: erfa_h = f.read() - funcs = {} + funcs = OrderedDict() section_subsection_functions = re.findall('/\* (\w*)/(\w*) \*/\n(.*?)\n\n', erfa_h, flags=re.DOTALL|re.MULTILINE) for section, subsection, functions in section_subsection_functions: diff --git a/astropy/config/configuration.py b/astropy/config/configuration.py index e22b78d04629..4e3779ddb8b9 100644 --- a/astropy/config/configuration.py +++ b/astropy/config/configuration.py @@ -26,6 +26,7 @@ from ..extern.configobj import configobj, validate from ..utils.exceptions import AstropyWarning, AstropyDeprecationWarning from ..utils import find_current_module +from ..utils.introspection import resolve_name from ..utils.misc import InheritDocstrings from .paths import get_config_dir @@ -219,6 +220,11 @@ class _Conf(config.ConfigNamespace): # this is used to make validation faster so a Validator object doesn't # have to be created every time _validator = validate.Validator() + cfgtype = None + """ + A type specifier like those used as the *values* of a particular key in a + ``configspec`` file of ``configobj``. + """ def __init__(self, defaultvalue='', description=None, cfgtype=None, module=None, aliases=None): @@ -308,12 +314,14 @@ def set_temp(self, value): Sets this item to a specified value only inside a with block. Use as:: + ITEM = ConfigItem('ITEM', 'default', 'description') with ITEM.set_temp('newval'): - ... do something that wants ITEM's value to be 'newval' ... + #... do something that wants ITEM's value to be 'newval' ... + print(ITEM) - # ITEM is now 'default' after the with block + # ITEM is now 'default' after the with block Parameters ---------- @@ -565,11 +573,7 @@ def _deprecation_warning(self): AstropyDeprecationWarning) def _get_target(self): - if self._new_module not in sys.modules: - __import__(self._new_module) - mod = sys.modules[self._new_module] - cfg = getattr(mod, 'conf') - return cfg + return resolve_name(self._new_module, 'conf') def set(self, value): self._deprecation_warning() @@ -841,8 +845,7 @@ def update_default_config(pkg, default_cfg_dir_or_fn, version=None): identical = False if version is None: - mod = __import__(pkg) - version = mod.__version__ + version = resolve_name(pkg, '__version__') # Don't install template files for dev versions, or we'll end up # spamming `~/.astropy/config`. diff --git a/astropy/coordinates/baseframe.py b/astropy/coordinates/baseframe.py index b3b05fa3c9f6..eb8c9c64bf92 100644 --- a/astropy/coordinates/baseframe.py +++ b/astropy/coordinates/baseframe.py @@ -22,7 +22,7 @@ from ..extern import six from ..utils.exceptions import AstropyDeprecationWarning, AstropyWarning from .. import units as u -from ..utils import OrderedDict +from ..utils import OrderedDict, OrderedDescriptor, OrderedDescriptorContainer from .transformations import TransformGraph from .representation import (BaseRepresentation, CartesianRepresentation, SphericalRepresentation, @@ -58,7 +58,7 @@ def _get_repr_cls(value): return value -class FrameMeta(type): +class FrameMeta(OrderedDescriptorContainer): def __new__(mcls, name, bases, members): if 'default_representation' in members: default_repr = members.pop('default_representation') @@ -123,7 +123,7 @@ def getter(self): members[attr] = property(getter) -class FrameAttribute(object): +class FrameAttribute(OrderedDescriptor): """A non-mutable data descriptor to hold a frame attribute. This class must be used to define frame attributes (e.g. ``equinox`` or @@ -158,20 +158,14 @@ class FK4(BaseCoordinateFrame): ``default is None`` and no value was supplied during initialization. """ - _nextid = 1 - """ - Used to ascribe some ordering to FrameAttribute instances so that the - order they were assigned in a class body can be determined. - """ + _class_attribute_ = 'frame_attributes' + _name_attribute_ = 'name' + name = '' def __init__(self, default=None, secondary_attribute=''): self.default = default self.secondary_attribute = secondary_attribute - - # Use FrameAttribute._nextid explicitly so that subclasses of - # FrameAttribute use the same counter - self._order = FrameAttribute._nextid - FrameAttribute._nextid += 1 + super(FrameAttribute, self).__init__() def convert_input(self, value): """ @@ -207,25 +201,6 @@ def convert_input(self, value): return value, False def __get__(self, instance, frame_cls=None): - if not hasattr(self, 'name'): - # Find attribute name of self by finding this object in the frame - # class which is requesting this attribute or any of its - # superclasses. - for mro_cls in frame_cls.__mro__: - for name, val in mro_cls.__dict__.items(): - if val is self: - self.name = name - break - if hasattr(self, 'name'): # Can't nicely break out of two loops - break - else: - # Cannot think of a way to actually raise this exception. This - # instance containing this code must be in the class dict in - # order to get excecuted by attribute access. But leave this - # here just in case... - raise AttributeError( - 'Unexpected inability to locate descriptor') - out = None if instance is not None: @@ -470,6 +445,11 @@ class BaseCoordinateFrame(object): # specifies special names/units for representation attributes frame_specific_representation_info = {} + _inherit_descriptors_ = (FrameAttribute,) + + frame_attributes = OrderedDict() + # Default empty frame_attributes dict + # This __new__ provides for backward-compatibility with pre-0.4 API. # TODO: remove in 1.0 def __new__(cls, *args, **kwargs): @@ -482,24 +462,21 @@ def __new__(cls, *args, **kwargs): use_skycoord = False - if (len(args) > 1 or (len(args) == 1 and - not isinstance(args[0], BaseRepresentation))): - for arg in args: - if (not isinstance(arg, u.Quantity) - and not isinstance(arg, BaseRepresentation)): - msg = ('Initializing frame classes like "{0}" using string ' - 'or other non-Quantity arguments is deprecated, and ' - 'will be removed in the next version of Astropy. ' - 'Instead, you probably want to use the SkyCoord ' - 'class with the "frame={1}" keyword, or if you ' - 'really want to use the low-level frame classes, ' - 'create it with an Angle or Quantity.') - - warnings.warn(msg.format(cls.__name__, - cls.__name__.lower()), - AstropyDeprecationWarning) - use_skycoord = True - break + for arg in args: + if not isinstance(arg, (u.Quantity, BaseRepresentation)): + msg = ('Initializing frame classes like "{0}" using string ' + 'or other non-Quantity arguments is deprecated, and ' + 'will be removed in the next version of Astropy. ' + 'Instead, you probably want to use the SkyCoord ' + 'class with the "frame={1}" keyword, or if you ' + 'really want to use the low-level frame classes, ' + 'create it with an Angle or Quantity.') + + warnings.warn(msg.format(cls.__name__, + cls.__name__.lower()), + AstropyDeprecationWarning) + use_skycoord = True + break if 'unit' in kwargs and not use_skycoord: warnings.warn( @@ -643,20 +620,8 @@ def isscalar(self): @classmethod def get_frame_attr_names(cls): - seen = set() - attributes = [] - for mro_cls in cls.__mro__: - for name, val in mro_cls.__dict__.items(): - if isinstance(val, FrameAttribute) and name not in seen: - seen.add(name) - # Add the sort order, name, and actual value of the frame - # attribute in question - attributes.append((val._order, name, - getattr(mro_cls, name))) - - # Sort by the frame attribute order - attributes.sort(key=lambda a: a[0]) - return OrderedDict((a[1], a[2]) for a in attributes) + return OrderedDict((name, getattr(cls, name)) + for name in cls.frame_attributes) @property def representation(self): @@ -1062,10 +1027,10 @@ def __getattr__(self, attr): return val def __setattr__(self, attr, value): - repr_attr_names = [] + repr_attr_names = set() if hasattr(self, 'representation_info'): for representation_attr in self.representation_info.values(): - repr_attr_names.extend(representation_attr['names']) + repr_attr_names.update(representation_attr['names']) if attr in repr_attr_names: raise AttributeError( 'Cannot set any frame attribute {0}'.format(attr)) @@ -1181,17 +1146,16 @@ class GenericFrame(BaseCoordinateFrame): A dictionary of attributes to be used as the frame attributes for this frame. """ + name = None # it's not a "real" frame so it doesn't have a name def __init__(self, frame_attrs): - super(GenericFrame, self).__setattr__('_frame_attr_names', frame_attrs) - super(GenericFrame, self).__init__(None) - - for attrnm, attrval in frame_attrs.items(): - setattr(self, '_' + attrnm, attrval) + self.frame_attributes = OrderedDict() + for name, default in frame_attrs.items(): + self.frame_attributes[name] = FrameAttribute(default) + setattr(self, '_' + name, default) - def get_frame_attr_names(self): - return self._frame_attr_names + super(GenericFrame, self).__init__(None) def __getattr__(self, name): if '_' + name in self.__dict__: @@ -1200,7 +1164,7 @@ def __getattr__(self, name): raise AttributeError('no {0}'.format(name)) def __setattr__(self, name, value): - if name in self._frame_attr_names: + if name in self.get_frame_attr_names(): raise AttributeError("can't set frame attribute '{0}'".format(name)) else: super(GenericFrame, self).__setattr__(name, value) diff --git a/astropy/coordinates/builtin_frames/utils.py b/astropy/coordinates/builtin_frames/utils.py index 9fa378173e40..97177ccde0e0 100644 --- a/astropy/coordinates/builtin_frames/utils.py +++ b/astropy/coordinates/builtin_frames/utils.py @@ -33,11 +33,10 @@ _IERS_HINT = """ If you need enough precision such that this matters (~<10 arcsec), you can -download the latest IERS predictions by running: +use the latest IERS predictions by running: - >>> from astropy.utils.data import download_file >>> from astropy.utils import iers - >>> iers.IERS.iers_table = iers.IERS_A.open(download_file(iers.IERS_A_URL, cache=True)) + >>> iers.IERS.iers_table = iers.IERS_A.open(iers.IERS_A_URL) """ diff --git a/astropy/coordinates/tests/test_frames.py b/astropy/coordinates/tests/test_frames.py index 410ed7f44b8b..fc1398e3d1f3 100644 --- a/astropy/coordinates/tests/test_frames.py +++ b/astropy/coordinates/tests/test_frames.py @@ -7,8 +7,10 @@ import numpy as np from ... import units as u +from ...extern import six from ...tests.helper import (pytest, quantity_allclose as allclose, assert_quantity_allclose as assert_allclose) +from ...utils import OrderedDescriptorContainer from .. import representation NUMPY_LT_1P7 = [int(x) for x in np.__version__.split('.')[:2]] < [1, 7] @@ -17,6 +19,7 @@ def test_frame_attribute_descriptor(): """ Unit tests of the FrameAttribute descriptor """ from ..baseframe import FrameAttribute + @six.add_metaclass(OrderedDescriptorContainer) class TestFrameAttributes(object): attr_none = FrameAttribute() attr_2 = FrameAttribute(default=2) diff --git a/astropy/io/ascii/setup_package.py b/astropy/io/ascii/setup_package.py index 982c084fce91..6a8c1d778b5d 100644 --- a/astropy/io/ascii/setup_package.py +++ b/astropy/io/ascii/setup_package.py @@ -48,6 +48,7 @@ def get_package_data(): 't/html2.html', 't/ipac.dat', 't/ipac.dat.bz2', + 't/ipac.dat.xz', 't/latex1.tex', 't/latex1.tex.gz', 't/latex2.tex', @@ -61,6 +62,7 @@ def get_package_data(): 't/short.rdb', 't/short.rdb.bz2', 't/short.rdb.gz', + 't/short.rdb.xz', 't/short.tab', 't/simple.txt', 't/simple2.txt', diff --git a/astropy/io/ascii/tests/t/ipac.dat.xz b/astropy/io/ascii/tests/t/ipac.dat.xz new file mode 100644 index 000000000000..cf06f78fef00 Binary files /dev/null and b/astropy/io/ascii/tests/t/ipac.dat.xz differ diff --git a/astropy/io/ascii/tests/t/short.rdb.xz b/astropy/io/ascii/tests/t/short.rdb.xz new file mode 100644 index 000000000000..93faba985344 Binary files /dev/null and b/astropy/io/ascii/tests/t/short.rdb.xz differ diff --git a/astropy/io/ascii/tests/test_compressed.py b/astropy/io/ascii/tests/test_compressed.py index b5442be8b102..c7b56104da4d 100644 --- a/astropy/io/ascii/tests/test_compressed.py +++ b/astropy/io/ascii/tests/test_compressed.py @@ -1,5 +1,6 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst import os +import sys import numpy as np from ....tests.helper import pytest @@ -15,6 +16,16 @@ else: HAS_BZ2 = True +try: + if sys.version_info >= (3,3,0): + import lzma + else: + from backports import lzma +except ImportError: + HAS_XZ = False +else: + HAS_XZ = True + @pytest.mark.parametrize('filename', ['t/daophot.dat.gz', 't/latex1.tex.gz', 't/short.rdb.gz']) @@ -32,3 +43,12 @@ def test_bzip2(filename): t_uncomp = read(os.path.join(ROOT, filename.replace('.bz2', ''))) assert t_comp.dtype.names == t_uncomp.dtype.names assert np.all(t_comp.as_array() == t_uncomp.as_array()) + + +@pytest.mark.xfail('not HAS_XZ') +@pytest.mark.parametrize('filename', ['t/short.rdb.xz', 't/ipac.dat.xz']) +def test_xz(filename): + t_comp = read(os.path.join(ROOT, filename)) + t_uncomp = read(os.path.join(ROOT, filename.replace('.xz', ''))) + assert t_comp.dtype.names == t_uncomp.dtype.names + assert np.all(t_comp.as_array() == t_uncomp.as_array()) diff --git a/astropy/io/ascii/tests/test_read.py b/astropy/io/ascii/tests/test_read.py index 79eda7583673..1ba44703c5ed 100644 --- a/astropy/io/ascii/tests/test_read.py +++ b/astropy/io/ascii/tests/test_read.py @@ -18,6 +18,13 @@ from .. import core from ..ui import _probably_html, get_read_trace +try: + import bz2 +except ImportError: + HAS_BZ2 = False +else: + HAS_BZ2 = True + @pytest.mark.parametrize('fast_reader', [True, False, 'force']) def test_convert_overflow(fast_reader): @@ -887,6 +894,7 @@ def test_guess_fail(): assert 'Number of header columns (1) inconsistent with data columns (3)' in str(err.value) +@pytest.mark.xfail('not HAS_BZ2') def test_guessing_file_object(): """ Test guessing a file object. Fixes #3013 and similar issue noted in #3019. diff --git a/astropy/io/ascii/ui.py b/astropy/io/ascii/ui.py index 7bc99d03aa0a..795d24a36dd8 100644 --- a/astropy/io/ascii/ui.py +++ b/astropy/io/ascii/ui.py @@ -285,6 +285,8 @@ def read(table, guess=None, **kwargs): try: with get_readable_fileobj(table) as fileobj: table = fileobj.read() + except ValueError: # unreadable or invalid binary file + raise except: pass else: diff --git a/astropy/io/fits/header.py b/astropy/io/fits/header.py index 82c55863f9c9..880d5193430f 100644 --- a/astropy/io/fits/header.py +++ b/astropy/io/fits/header.py @@ -1359,7 +1359,7 @@ def extend(self, cards, strip=True, unique=False, update=False, else: extend_cards.append(card) else: - if unique or update and keyword in self: + if (unique or update) and keyword in self: if card.is_blank: extend_cards.append(card) continue diff --git a/astropy/io/fits/tests/test_header.py b/astropy/io/fits/tests/test_header.py index 754039804985..6c72879a0adc 100644 --- a/astropy/io/fits/tests/test_header.py +++ b/astropy/io/fits/tests/test_header.py @@ -1184,10 +1184,27 @@ def test_header_extend_unique(self): hdu = fits.PrimaryHDU() hdu2 = fits.ImageHDU() hdu.header['MYKEY'] = ('some val', 'some comment') + hdu2.header['MYKEY'] = ('some other val', 'some other comment') hdu.header.extend(hdu2.header, unique=True) assert len(hdu.header) == 5 assert hdu.header[-1] == 'some val' + def test_header_extend_unique_commentary(self): + """ + Test extending header with and without unique=True and commentary + cards in the header being added. Issue astropy/astropy#3967 + """ + for commentary_card in ['', 'COMMENT', 'HISTORY']: + for is_unique in [True, False]: + hdu = fits.PrimaryHDU() + # Make sure we are testing the case we want. + assert commentary_card not in hdu.header + hdu2 = fits.ImageHDU() + hdu2.header[commentary_card] = 'My text' + hdu.header.extend(hdu2.header, unique=is_unique) + assert len(hdu.header) == 5 + assert hdu.header[commentary_card][0] == 'My text' + def test_header_extend_update(self): """ Test extending the header with and without update=True. @@ -1219,6 +1236,25 @@ def test_header_extend_update(self): assert len(hdu.header['HISTORY']) == 2 assert hdu.header[-1] == 'history 2' + def test_header_extend_update_commentary(self): + """ + Test extending header with and without unique=True and commentary + cards in the header being added. + + Though not quite the same as astropy/astropy#3967, update=True hits + the same if statement as that issue. + """ + for commentary_card in ['', 'COMMENT', 'HISTORY']: + for is_update in [True, False]: + hdu = fits.PrimaryHDU() + # Make sure we are testing the case we want. + assert commentary_card not in hdu.header + hdu2 = fits.ImageHDU() + hdu2.header[commentary_card] = 'My text' + hdu.header.extend(hdu2.header, update=is_update) + assert len(hdu.header) == 5 + assert hdu.header[commentary_card][0] == 'My text' + def test_header_extend_exact(self): """ Test that extending an empty header with the contents of an existing diff --git a/astropy/io/misc/tests/test_hdf5.py b/astropy/io/misc/tests/test_hdf5.py index 19c8cf0f0f35..3433ab1f9fbe 100644 --- a/astropy/io/misc/tests/test_hdf5.py +++ b/astropy/io/misc/tests/test_hdf5.py @@ -415,3 +415,5 @@ def test_read_h5py_objects(tmpdir): t4 = Table.read(f['the_table']) assert np.all(t4['a'] == [1, 2, 3]) + + f.close() # don't raise an error in 'test --open-files' diff --git a/astropy/modeling/core.py b/astropy/modeling/core.py index 9d4e5bd3b905..3c163e3e3d0c 100644 --- a/astropy/modeling/core.py +++ b/astropy/modeling/core.py @@ -34,9 +34,10 @@ from ..extern.six.moves import copyreg from ..table import Table from ..utils import (deprecated, sharedmethod, find_current_module, - InheritDocstrings) + InheritDocstrings, OrderedDescriptorContainer) from ..utils.codegen import make_function_with_signature from ..utils.compat import ignored +from ..utils.compat.odict import OrderedDict from ..utils.exceptions import AstropyDeprecationWarning from .utils import (array_repr_oneline, check_broadcast, combine_labels, make_binary_operator_eval, ExpressionTree, @@ -47,7 +48,7 @@ __all__ = ['Model', 'FittableModel', 'Fittable1DModel', 'Fittable2DModel', - 'custom_model', 'ModelDefinitionError', 'render_model'] + 'custom_model', 'ModelDefinitionError'] class ModelDefinitionError(TypeError): @@ -73,7 +74,7 @@ def _model_oper(oper, **kwargs): left, right, **kwargs) -class _ModelMeta(InheritDocstrings, abc.ABCMeta): +class _ModelMeta(OrderedDescriptorContainer, InheritDocstrings, abc.ABCMeta): """ Metaclass for Model. @@ -99,23 +100,37 @@ class _ModelMeta(InheritDocstrings, abc.ABCMeta): creating them. """ + # Default empty dict for _parameters_, which will be empty on model + # classes that don't have any Parameters + _parameters_ = OrderedDict() + def __new__(mcls, name, bases, members): # See the docstring for _is_dynamic above if '_is_dynamic' not in members: members['_is_dynamic'] = mcls._is_dynamic - parameters = mcls._handle_parameters(name, members) - mcls._create_inverse_property(members) - mcls._handle_backwards_compat(name, members) + return super(_ModelMeta, mcls).__new__(mcls, name, bases, members) + + def __init__(cls, name, bases, members): + # Make sure OrderedDescriptorContainer gets to run before doing + # anything else + super(_ModelMeta, cls).__init__(name, bases, members) - cls = super(_ModelMeta, mcls).__new__(mcls, name, bases, members) + if cls._parameters_: + if hasattr(cls, '_param_names'): + # Slight kludge to support compound models, where + # cls.param_names is a property; could be improved with a + # little refactoring but fine for now + cls._param_names = tuple(cls._parameters_) + else: + cls.param_names = tuple(cls._parameters_) - mcls._handle_special_methods(members, cls, parameters) + cls._create_inverse_property(members) + cls._handle_backwards_compat(name, members) + cls._handle_special_methods(members) if not inspect.isabstract(cls) and not name.startswith('_'): - mcls.registry.add(cls) - - return cls + cls.registry.add(cls) def __repr__(cls): """ @@ -229,58 +244,8 @@ def rename(cls, name): return new_cls - @classmethod - def _handle_parameters(mcls, name, members): - # Handle parameters - param_names = members.get('param_names', ()) - parameters = {} - for key, value in members.items(): - if not isinstance(value, Parameter): - continue - if not value.name: - # Name not explicitly given in the constructor; add the name - # automatically via the attribute name - value._name = key - value._attr = '_' + key - if value.name != key: - raise ModelDefinitionError( - "Parameters must be defined with the same name as the " - "class attribute they are assigned to. Parameters may " - "take their name from the class attribute automatically " - "if the name argument is not given when initializing " - "them.") - parameters[value.name] = value - - # If no parameters were defined get out early--this is especially - # important for PolynomialModels which take a different approach to - # parameters, since they can have a variable number of them - if parameters: - mcls._check_parameters(name, members, param_names, parameters) - - return parameters - - @staticmethod - def _check_parameters(name, members, param_names, parameters): - # If param_names was declared explicitly we use only the parameters - # listed manually in param_names, but still check that all listed - # parameters were declared - if param_names and isiterable(param_names): - for param_name in param_names: - if param_name not in parameters: - raise RuntimeError( - "Parameter {0!r} listed in {1}.param_names was not " - "declared in the class body.".format(param_name, name)) - else: - param_names = tuple(param.name for param in - sorted(parameters.values(), - key=lambda p: p._order)) - members['param_names'] = param_names - members['_param_orders'] = \ - dict((name, idx) for idx, name in enumerate(param_names)) - - @staticmethod - def _create_inverse_property(members): - inverse = members.get('inverse', None) + def _create_inverse_property(cls, members): + inverse = members.get('inverse') if inverse is None: return @@ -307,11 +272,9 @@ def fset(self, value): self._custom_inverse = value - members['inverse'] = property(wrapped_fget, fset, - doc=inverse.__doc__) + cls.inverse = property(wrapped_fget, fset, doc=inverse.__doc__) - @classmethod - def _handle_backwards_compat(mcls, name, members): + def _handle_backwards_compat(cls, name, members): # Backwards compatibility check for 'eval' -> 'evaluate' # TODO: Remove sometime after Astropy 1.0 release. if 'eval' in members and 'evaluate' not in members: @@ -320,7 +283,7 @@ def _handle_backwards_compat(mcls, name, members): "FittableModel is deprecated; please rename this method to " "'evaluate'. Otherwise its semantics remain the same.", AstropyDeprecationWarning) - members['evaluate'] = members['eval'] + cls.evaluate = members['eval'] elif ('evaluate' in members and callable(members['evaluate']) and not getattr(members['evaluate'], '__isabstractmethod__', False)): @@ -329,10 +292,9 @@ def _handle_backwards_compat(mcls, name, members): # abstractmethod as well alt = '.'.join((name, 'evaluate')) deprecate = deprecated('1.0', alternative=alt, name='eval') - members['eval'] = deprecate(members['evaluate']) + cls.eval = deprecate(members['evaluate']) - @classmethod - def _handle_special_methods(mcls, members, cls, parameters): + def _handle_special_methods(cls, members): # Handle init creation from inputs def update_wrapper(wrapper, cls): # Set up the new __call__'s metadata attributes as though it were @@ -363,14 +325,14 @@ def __call__(self, *inputs, **kwargs): cls.__call__ = new_call if ('__init__' not in members and not inspect.isabstract(cls) and - parameters): + cls._parameters_): # If *all* the parameters have default values we can make them # keyword arguments; otherwise they must all be positional # arguments if all(p.default is not None - for p in six.itervalues(parameters)): + for p in six.itervalues(cls._parameters_)): args = ('self',) - kwargs = [(name, parameters[name].default) + kwargs = [(name, cls._parameters_[name].default) for name in cls.param_names] else: args = ('self',) + cls.param_names @@ -588,10 +550,7 @@ class Model(object): # it to. _custom_inverse = None - # If a bounding_box_default function is defined in the model, - # then the _bounding_box attribute should be set to 'auto' in the model. - # Otherwise, the default is None for no bounding box. - _bounding_box = None + _bounding_box = 'auto' # Default n_models attribute, so that __len__ is still defined even when a # model hasn't completed initialization yet @@ -789,8 +748,8 @@ def ineqcons(self): @property def inverse(self): """ - Returns a new `Model` instance which performs the inverse - transform, if an analytic inverse is defined for this model. + Returns a new `~astropy.modeling.Model` instance which performs the + inverse transform, if an analytic inverse is defined for this model. Even on models that don't have an inverse defined, this property can be set with a manually-defined inverse, such a pre-computed or @@ -798,49 +757,56 @@ def inverse(self): `~astropy.modeling.polynomial.PolynomialModel`, but not by requirement). - Note to authors of `Model` subclasses: To define an inverse for a - model simply override this property to return the appropriate model - representing the inverse. The machinery that will make the inverse - manually-overridable is added automatically by the base class. + Note to authors of `~astropy.modeling.Model` subclasses: To define an + inverse for a model simply override this property to return the + appropriate model representing the inverse. The machinery that will + make the inverse manually-overridable is added automatically by the + base class. """ raise NotImplementedError("An analytical inverse transform has not " "been implemented for this model.") + def bounding_box_default(self): + """ + Raises a ``NotImplementedError`` by default. This is overridden by defining + `bounding_box_default` in the subclass. + """ + raise NotImplementedError("The bounding box is not set for this model.") + @property def bounding_box(self): """ A `tuple` of length `n_inputs` defining the bounding box limits, or `None` for no bounding box. - The default is `None`, unless ``bounding_box_default`` is defined. - `bounding_box` can be set manually to an array-like object of shape - ``(model.n_inputs, 2)``. For further usage, including how to set the - ``bounding_box_default``, see :ref:`bounding-boxes` + The default limits are given by the `bounding_box_default` method, which + in some cases are `None`. `bounding_box` can be set manually to an + array-like object of shape ``(model.n_inputs, 2)``. For further usage, + including how to set `bounding_box_default`, see :ref:`bounding-boxes` The limits are ordered according to the `numpy` indexing convention, and are the reverse of the model input order, - e.g. for inputs ``('x', 'y', 'z')`` the ``bounding_box`` is defined: - + e.g. for inputs ``('x', 'y', 'z')``, `bounding_box` is defined: + * for 1D: ``(x_low, x_high)`` * for 2D: ``((y_low, y_high), (x_low, x_high))`` * for 3D: ``((z_low, z_high), (y_low, y_high), (x_low, x_high))`` Examples -------- - Setting the bounding boxes for a 1D, 2D, and custom 3D model. - >>> from astropy.modeling.models import Gaussian1D, Gaussian2D, custom_model + Setting the `bounding_box` limits for a 1D and 2D model. + + >>> from astropy.modeling.models import Gaussian1D, Gaussian2D >>> model_1d = Gaussian1D() >>> model_2d = Gaussian2D(x_stddev=1, y_stddev=1) - - Set the bounding box like: - >>> model_1d.bounding_box = (-5, 5) >>> model_2d.bounding_box = ((-6, 6), (-5, 5)) - For a user-defined 3D model: + Setting the bounding_box limits for a user-defined 3D `custom_model`: + >>> from astropy.modeling.models import custom_model >>> def const3d(x, y, z, amp=1): ... return amp ... @@ -848,21 +814,23 @@ def bounding_box(self): >>> model_3d = Const3D() >>> model_3d.bounding_box = ((-6, 6), (-5, 5), (-4, 4)) - To reset the default: + To reset `bounding_box` to its default limits: >>> model_1d.bounding_box = 'auto' - To turn off the bounding box: + To turn off or unset `bounding_box`: >>> model_1d.bounding_box = None - """ - if self._bounding_box == 'auto': - return self.bounding_box_default() + if self._bounding_box is None: + raise NotImplementedError("No bounding box is set for this model.") + + elif self._bounding_box =='auto': + return self.bounding_box_default() else: - return self._bounding_box + return self._bounding_box @bounding_box.setter def bounding_box(self, limits): @@ -870,12 +838,7 @@ def bounding_box(self, limits): Assigns the bounding box limits. """ - if limits == 'auto': - if not hasattr(self, 'bounding_box_default'): - warnings.warn('The default for this model is None.') - limits = None - - elif limits is None: + if limits in ('auto', None): pass else: @@ -890,7 +853,7 @@ def bounding_box(self, limits): limits = tuple([tuple(lim) for lim in limits]) except AssertionError: - raise AssertionError('If not \'auto\' or None, bounding_box must be ' + raise ValueError('If not \'auto\' or None, bounding_box must be ' 'array-like of shape ``(model.n_inputs, 2)``.') self._bounding_box = limits @@ -899,6 +862,104 @@ def bounding_box(self, limits): def evaluate(self, *args, **kwargs): """Evaluate the model on some input variables.""" + def render(self, out=None, coords=None): + """ + Evaluates a model on an input array. Evaluation is limited to + a bounding box if the `Model.bounding_box` attribute is set. + + Parameters + ---------- + out : `numpy.ndarray`, optional + The array on which the model is to be evaluated. + coords : array-like, optional + Coordinate arrays mapping to ``arr``, such that + ``arr[coords] == arr``. + + Returns + ------- + out : `numpy.ndarray` + The model evaluated on the input array if given, or else a new array from + ``coords``. + If ``out`` and ``coords`` are both `None`, the returned array is + limited to the `Model.bounding_box` limits. If + `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be passed. + + Examples + -------- + :ref:`bounding-boxes` + """ + + try: + bbox = self.bounding_box + except NotImplementedError: + bbox = None + + ndim = self.n_inputs + + if (coords is None) and (out is None) and (bbox is None): + raise ValueError('If no bounding_box is set, ' + 'coords or out must be input.') + + # for consistent indexing + if ndim == 1: + if coords is not None: + coords = [coords] + if bbox is not None: + bbox = [bbox] + + if coords is not None: + # Check dimensions match out and model + assert len(coords) == ndim + if out is not None: + assert coords[0].shape == out.shape + else: + out = np.zeros(coords[0].shape) + + if out is not None: + try: + assert out.ndim == ndim + except AssertionError: + raise AssertionError('The array and model must have the same number ' + 'of dimensions.') + + if bbox is not None: + + # assures position is at center pixel, important when using add_array + pd = pos, delta = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2)) + for bb in bbox]).astype(int).T + + if coords is not None: + sub_shape = tuple(delta * 2 + 1) + sub_coords = np.array([extract_array(c, sub_shape, pos) + for c in coords]) + else: + limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T] + sub_coords = np.mgrid[limits] + + sub_coords = sub_coords[::-1] + + if out is None: + out = self(*sub_coords) + else: + try: + out = add_array(out, self(*sub_coords), pos) + except ValueError: + raise ValueError('The `bounding_box` is larger than the input' + ' out in one or more dimensions. Set ' + '`model.bounding_box = None`.') + else: + + if coords is None: + im_shape = out.shape + limits = [slice(i) for i in im_shape] + coords = np.mgrid[limits] + + coords = coords[::-1] + + out += self(*coords) + + return out + def prepare_inputs(self, *inputs, **kwargs): """ This method is used in `~astropy.modeling.Model.__call__` to ensure @@ -1447,13 +1508,12 @@ def _format_str(self, keywords=[]): columns = [getattr(self, name).value for name in self.param_names] - param_table = Table(columns, names=self.param_names) - - parts.append(indent(str(param_table), width=4)) + if columns: + param_table = Table(columns, names=self.param_names) + parts.append(indent(str(param_table), width=4)) return '\n'.join(parts) - class FittableModel(Model): """ Base class for models that can be fitted using the built-in fitting @@ -2408,96 +2468,6 @@ def _custom_model_wrapper(func, fit_deriv=None): return type(model_name, (FittableModel,), members) - -def render_model(model, arr=None, coords=None): - """ - Evaluates a model on an input array. Evaluation is limited to - a bounding box if the `Model.bounding_box` attribute is set. - - Parameters - ---------- - model : `Model` - Model to be evaluated. - arr : `numpy.ndarray`, optional - Array on which the model is evaluated. - coords : array-like, optional - Coordinate arrays mapping to ``arr``, such that - ``arr[coords] == arr``. - - Returns - ------- - array : `numpy.ndarray` - The model evaluated on the input ``arr`` or a new array from ``coords``. - If ``arr`` and ``coords`` are both `None`, the returned array is - limited to the `Model.bounding_box` limits. If - `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be passed. - - Examples - -------- - :ref:`bounding-boxes` - """ - - bbox = model.bounding_box - - if (coords is None) & (arr is None) & (bbox is None): - raise AssertionError('If no bounding_box is set, coords or arr must be input.') - - # for consistent indexing - if model.n_inputs == 1: - if coords is not None: - coords = [coords] - if bbox is not None: - bbox = [bbox] - - if arr is not None: - arr = arr.copy() - # Check dimensions match model - assert arr.ndim == model.n_inputs - - if coords is not None: - # Check dimensions match arr and model - coords = np.array(coords) - assert len(coords) == model.n_inputs - if arr is not None: - assert coords[0].shape == arr.shape - else: - arr = np.zeros(coords[0].shape) - - if bbox is not None: - # assures position is at center pixel, important when using add_array - pd = pos, delta = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2)) - for bb in bbox]).astype(int).T - - if coords is not None: - sub_shape = tuple(delta * 2 + 1) - sub_coords = np.array([extract_array(c, sub_shape, pos) for c in coords]) - else: - limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T] - sub_coords = np.mgrid[limits] - - sub_coords = sub_coords[::-1] - - if arr is None: - arr = model(*sub_coords) - else: - try: - arr = add_array(arr, model(*sub_coords), pos) - except ValueError: - raise ValueError('The `bounding_box` is larger than the input' - ' arr in one or more dimensions. Set ' - '`model.bounding_box = None`.') - else: - - if coords is None: - im_shape = arr.shape - limits = [slice(i) for i in im_shape] - coords = np.mgrid[limits] - - arr += model(*coords[::-1]) - - return arr - - def _prepare_inputs_single_model(model, params, inputs, **kwargs): broadcasts = [] diff --git a/astropy/modeling/functional_models.py b/astropy/modeling/functional_models.py index 332760a41dfd..782de2911ae9 100644 --- a/astropy/modeling/functional_models.py +++ b/astropy/modeling/functional_models.py @@ -92,7 +92,6 @@ class Gaussian1D(Fittable1DModel): amplitude = Parameter(default=1) mean = Parameter(default=0) stddev = Parameter(default=1) - _bounding_box = 'auto' def bounding_box_default(self, factor=5.5): """ @@ -103,7 +102,7 @@ def bounding_box_default(self, factor=5.5): ---------- factor : float The multiple of `stddev` used to define the limits. - The default is 5.5-sigma, corresponding to a relative error < 1e-7. + The default is 5.5, corresponding to a relative error < 1e-7. Examples -------- @@ -112,7 +111,7 @@ def bounding_box_default(self, factor=5.5): >>> model.bounding_box (-11.0, 11.0) - This range can be set directly (see: `astropy.modeling.Model.bounding_box`) or by + This range can be set directly (see: ``help(model.bounding_box)``) or by using a different factor, like: >>> model.bounding_box = model.bounding_box_default(factor=2) @@ -173,6 +172,37 @@ class GaussianAbsorption1D(Fittable1DModel): mean = Parameter(default=0) stddev = Parameter(default=1) + def bounding_box_default(self, factor=5.5): + """ + Tuple defining the default ``bounding_box`` limits, + ``(x_low, x_high)`` + + Parameters + ---------- + factor : float + The multiple of `stddev` used to define the limits. + The default is 5.5, corresponding to a relative error < 1e-7. + + Examples + -------- + >>> from astropy.modeling.models import Gaussian1D + >>> model = Gaussian1D(mean=0, stddev=2) + >>> model.bounding_box + (-11.0, 11.0) + + This range can be set directly (see: ``help(model.bounding_box)``) or by + using a different factor, like: + + >>> model.bounding_box = model.bounding_box_default(factor=2) + >>> model.bounding_box + (-4.0, 4.0) + """ + + x0 = self.mean.value + dx = factor * self.stddev + + return (x0 - dx, x0 + dx) + @staticmethod def evaluate(x, amplitude, mean, stddev): """ @@ -273,7 +303,6 @@ class Gaussian2D(Fittable2DModel): x_stddev = Parameter(default=1) y_stddev = Parameter(default=1) theta = Parameter(default=0) - _bounding_box = 'auto' def __init__(self, amplitude=amplitude.default, x_mean=x_mean.default, y_mean=y_mean.default, x_stddev=None, y_stddev=None, @@ -336,7 +365,7 @@ def bounding_box_default(self, factor=5.5): >>> model.bounding_box ((-11.0, 11.0), (-5.5, 5.5)) - This range can be set directly (see: `astropy.modeling.Model.bounding_box`) or by + This range can be set directly (see: ``help(model.bounding_box)``) or by using a different factor like: >>> model.bounding_box = model.bounding_box_default(factor=2) @@ -442,7 +471,6 @@ def inverse(self): def evaluate(x, offset): return x + offset - class Scale(Model): """ Multiply a model by a factor. @@ -469,7 +497,6 @@ def inverse(self): def evaluate(x, factor): return factor * x - class Redshift(Fittable1DModel): """ One dimensional redshift model. @@ -509,7 +536,6 @@ def inverse(self): inv.z = 1.0 / (1.0 + self.z) - 1.0 return inv - class Sersic1D(Fittable1DModel): r""" One dimensional Sersic surface brightness profile. @@ -595,7 +621,6 @@ def evaluate(cls, r, amplitude, r_eff, n): """One dimensional Sersic profile function.""" return amplitude * np.exp(-cls._gammaincinv(2 * n, 0.5) * ((r / r_eff) ** (1 / n) - 1)) - class Sine1D(Fittable1DModel): """ One dimensional Sine model. @@ -642,7 +667,6 @@ def fit_deriv(x, amplitude, frequency, phase): np.cos(2 * np.pi * frequency * x + 2 * np.pi * phase)) return [d_amplitude, d_frequency, d_phase] - class Linear1D(Fittable1DModel): """ One dimensional Line model. @@ -684,7 +708,6 @@ def fit_deriv(x, slope, intercept): d_intercept = np.ones_like(x) return [d_slope, d_intercept] - class Lorentz1D(Fittable1DModel): """ One dimensional Lorentzian model. @@ -732,7 +755,6 @@ def fit_deriv(x, amplitude, x_0, fwhm): d_fwhm = 2 * amplitude * d_amplitude / fwhm * (1 - d_amplitude) return [d_amplitude, d_x_0, d_fwhm] - class Voigt1D(Fittable1DModel): """ One dimensional model for the Voigt profile. @@ -824,7 +846,6 @@ def fit_deriv(cls, x, x_0, amplitude_L, fwhm_L, fwhm_G): -constant * (V + (sqrt_ln2 / fwhm_G) * (2 * (x - x_0) * dVdx + fwhm_L * dVdy)) / fwhm_G] return dyda - class Const1D(Fittable1DModel): """ One dimensional Constant model. @@ -870,7 +891,6 @@ def fit_deriv(x, amplitude): d_amplitude = np.ones_like(x) return [d_amplitude] - class Const2D(Fittable2DModel): """ Two dimensional Constant model. @@ -909,7 +929,6 @@ def evaluate(x, y, amplitude): return x - class Ellipse2D(Fittable2DModel): """ A 2D Ellipse model. @@ -986,26 +1005,6 @@ class Ellipse2D(Fittable2DModel): a = Parameter(default=1) b = Parameter(default=1) theta = Parameter(default=0) - _bounding_box = 'auto' - - def bounding_box_default(self): - """ - Tuple defining the default ``bounding_box`` limits around the ellipse. - - ``((y_low, y_high), (x_low, x_high))`` - - References - ---------- - `astropy.modeling.Model.bounding_box` - """ - - a = self.a - b = self.b - theta = self.theta.value - dx, dy = ellipse_extent(a, b, theta) - - return ((self.y_0 - dy, self.y_0 + dy), - (self.x_0 - dx, self.x_0 + dx)) @staticmethod def evaluate(x, y, amplitude, x_0, y_0, a, b, theta): @@ -1020,6 +1019,20 @@ def evaluate(x, y, amplitude, x_0, y_0, a, b, theta): in_ellipse = (((numerator1 / a) ** 2 + (numerator2 / b) ** 2) <= 1.) return np.select([in_ellipse], [amplitude]) + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box`` limits. + + ``((y_low, y_high), (x_low, x_high))`` + """ + + a = self.a + b = self.b + theta = self.theta.value + dx, dy = ellipse_extent(a, b, theta) + + return ((self.y_0 - dy, self.y_0 + dy), + (self.x_0 - dx, self.x_0 + dx)) class Disk2D(Fittable2DModel): """ @@ -1067,6 +1080,16 @@ def evaluate(x, y, amplitude, x_0, y_0, R_0): return np.select([rr <= R_0 ** 2], [amplitude]) + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box`` limits. + + ``((y_low, y_high), (x_low, x_high))`` + """ + + return ((self.y_0 - self.R_0, self.y_0 + self.R_0), + (self.x_0 - self.R_0, self.x_0 + self.R_0)) + class Ring2D(Fittable2DModel): """ Two dimensional radial symmetric Ring model. @@ -1135,6 +1158,17 @@ def evaluate(x, y, amplitude, x_0, y_0, r_in, width): r_range = np.logical_and(rr >= r_in ** 2, rr <= (r_in + width) ** 2) return np.select([r_range], [amplitude]) + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box``. + + ``((y_low, y_high), (x_low, x_high))`` + """ + + dr = self.r_in + self.width + + return ((self.y_0 - dr, self.y_0 + dr), + (self.x_0 - dr, self.x_0 + dr)) class Delta1D(Fittable1DModel): """One dimensional Dirac delta function.""" @@ -1202,6 +1236,16 @@ def fit_deriv(cls, x, amplitude, x_0, width): d_width = np.zeros_like(x) return [d_amplitude, d_x_0, d_width] + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box`` limits. + + ``(x_low, x_high))`` + """ + + dx = self.width / 2 + + return (self.x_0 - dx, self.x_0 + dx) class Box2D(Fittable2DModel): """ @@ -1256,6 +1300,18 @@ def evaluate(x, y, amplitude, x_0, y_0, x_width, y_width): y <= y_0 + y_width / 2.) return np.select([np.logical_and(x_range, y_range)], [amplitude], 0) + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box``. + + ``((y_low, y_high), (x_low, x_high))`` + """ + + dx = self.x_width / 2 + dy = self.y_width / 2 + + return ((self.y_0 - dy, self.y_0 + dy), + (self.x_0 - dx, self.x_0 + dx)) class Trapezoid1D(Fittable1DModel): """ @@ -1302,6 +1358,16 @@ def evaluate(x, amplitude, x_0, width, slope): val_c = slope * (x4 - x) return np.select([range_a, range_b, range_c], [val_a, val_b, val_c]) + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box`` limits. + + ``(x_low, x_high))`` + """ + + dx = self.width / 2 + self.amplitude / self.slope + + return (self.x_0 - dx, self.x_0 + dx) class TrapezoidDisk2D(Fittable2DModel): """ @@ -1330,7 +1396,7 @@ class TrapezoidDisk2D(Fittable2DModel): y_0 = Parameter(default=0) R_0 = Parameter(default=1) slope = Parameter(default=1) - + @staticmethod def evaluate(x, y, amplitude, x_0, y_0, R_0, slope): """Two dimensional Trapezoid Disk model function""" @@ -1342,6 +1408,17 @@ def evaluate(x, y, amplitude, x_0, y_0, R_0, slope): val_2 = amplitude + slope * (R_0 - r) return np.select([range_1, range_2], [val_1, val_2]) + def bounding_box_default(self): + """ + Tuple defining the default ``bounding_box``. + + ``((y_low, y_high), (x_low, x_high))`` + """ + + dr = self.R_0 + self.amplitude / self.slope + + return ((self.y_0 - dr, self.y_0 + dr), + (self.x_0 - dr, self.x_0 + dr)) class MexicanHat1D(Fittable1DModel): """ @@ -1382,7 +1459,6 @@ def evaluate(x, amplitude, x_0, sigma): xx_ww = (x - x_0) ** 2 / (2 * sigma ** 2) return amplitude * (1 - 2 * xx_ww) * np.exp(-xx_ww) - class MexicanHat2D(Fittable2DModel): """ Two dimensional symmetric Mexican Hat model. @@ -1426,7 +1502,6 @@ def evaluate(x, y, amplitude, x_0, y_0, sigma): rr_ww = ((x - x_0) ** 2 + (y - y_0) ** 2) / (2 * sigma ** 2) return amplitude * (1 - rr_ww) * np.exp(- rr_ww) - class AiryDisk2D(Fittable2DModel): """ Two dimensional Airy disk model. @@ -1516,7 +1591,6 @@ def evaluate(cls, x, y, amplitude, x_0, y_0, radius): z *= amplitude return z - class Moffat1D(Fittable1DModel): """ One dimensional Moffat model. @@ -1568,7 +1642,6 @@ def fit_deriv(x, amplitude, x_0, gamma, alpha): d_alpha = -amplitude * d_A * np.log(1 + (x - x_0) ** 2 / gamma ** 2) return [d_A, d_x_0, d_gamma, d_alpha] - class Moffat2D(Fittable2DModel): """ Two dimensional Moffat model. @@ -1627,7 +1700,6 @@ def fit_deriv(x, y, amplitude, x_0, y_0, gamma, alpha): d_gamma = 2 * amplitude * alpha * d_A * (rr_gg / (gamma * (1 + rr_gg))) return [d_A, d_x_0, d_y_0, d_gamma, d_alpha] - class Sersic2D(Fittable2DModel): r""" Two dimensional Sersic surface brightness profile. @@ -1735,7 +1807,6 @@ def evaluate(cls, x, y, amplitude, r_eff, n, x_0, y_0, ellip, theta): return amplitude * np.exp(-bn * (z ** (1 / n) - 1)) - @deprecated('1.0', alternative='astropy.modeling.models.custom_model', pending=True) def custom_model_1d(func, func_fit_deriv=None): diff --git a/astropy/modeling/parameters.py b/astropy/modeling/parameters.py index 1ae3c684a852..53ec88d50e7e 100644 --- a/astropy/modeling/parameters.py +++ b/astropy/modeling/parameters.py @@ -17,8 +17,7 @@ import numpy as np -from ..utils import isiterable -from ..utils.compat import ignored +from ..utils import isiterable, OrderedDescriptor from ..extern import six __all__ = ['Parameter', 'InputParameterError'] @@ -55,7 +54,7 @@ def _tofloat(value): return value -class Parameter(object): +class Parameter(OrderedDescriptor): """ Wraps individual parameters. @@ -83,6 +82,13 @@ class Parameter(object): ---------- name : str parameter name + + .. warning:: + + The fact that `Parameter` accepts ``name`` as an argument is an + implementation detail, and should not be used directly. When + defining a new `Model` class, parameter names are always + automatically defined by the class attribute they're assigned to. description : str parameter description default : float or array @@ -121,8 +127,9 @@ class Parameter(object): constraint (which is represented as a 2-tuple). """ - # See the _nextid classmethod - _nextid = 1 + # Settings for OrderedDescriptor + _class_attribute_ = '_parameters_' + _name_attribute_ = '_name' def __init__(self, name='', description='', default=None, getter=None, setter=None, fixed=False, tied=False, min=None, max=None, @@ -163,9 +170,7 @@ def __init__(self, name='', description='', default=None, getter=None, # Only Parameters declared as class-level descriptors require # and ordering ID - if model is None: - self._order = self._get_nextid() - else: + if model is not None: self._bind(model) def __get__(self, obj, objtype): @@ -579,19 +584,6 @@ def _raw_value(self): return self._get_model_value(self._model) - @classmethod - def _get_nextid(cls): - """Returns a monotonically increasing ID used to order Parameter - descriptors declared at the class-level of Model subclasses. - - This allows the desired parameter order to be determined without - having to list it manually in the param_names class attribute. - """ - - nextid = cls._nextid - cls._nextid += 1 - return nextid - def _bind(self, model): """ Bind the `Parameter` to a specific `Model` instance; don't use this @@ -700,64 +692,113 @@ def __nonzero__(self): __bool__ = __nonzero__ def __add__(self, val): + if self._model is None: + # If we don't do this, __add__ will raise an AttributeError instead + # (from self.value) which is strange and unexpected + return NotImplemented return self.value + val def __radd__(self, val): + if self._model is None: + return NotImplemented return val + self.value def __sub__(self, val): + if self._model is None: + return NotImplemented return self.value - val def __rsub__(self, val): + if self._model is None: + return NotImplemented return val - self.value def __mul__(self, val): + if self._model is None: + return NotImplemented return self.value * val def __rmul__(self, val): + if self._model is None: + return NotImplemented return val * self.value def __pow__(self, val): + if self._model is None: + return NotImplemented return self.value ** val def __rpow__(self, val): + if self._model is None: + return NotImplemented return val ** self.value def __div__(self, val): + if self._model is None: + return NotImplemented return self.value / val def __rdiv__(self, val): + if self._model is None: + return NotImplemented return val / self.value def __truediv__(self, val): + if self._model is None: + return NotImplemented return self.value / val def __rtruediv__(self, val): + if self._model is None: + return NotImplemented return val / self.value def __eq__(self, val): if self._model is None: - return super(Parameter, self).__eq__(val) + return NotImplemented return self.__array__() == val def __ne__(self, val): + if self._model is None: + return NotImplemented + return self.__array__() != val def __lt__(self, val): + # Because OrderedDescriptor uses __lt__ to work, we need to call the + # super method, but only when not bound to an instance anyways + if self._model is None: + return super(Parameter, self).__lt__(val) + return self.__array__() < val def __gt__(self, val): + if self._model is None: + return NotImplemented + return self.__array__() > val def __le__(self, val): + if self._model is None: + return NotImplemented + return self.__array__() <= val def __ge__(self, val): + if self._model is None: + return NotImplemented + return self.__array__() >= val def __neg__(self): + if self._model is None: + return NotImplemented + return -self.value def __abs__(self): + if self._model is None: + return NotImplemented + return np.abs(self.value) diff --git a/astropy/modeling/polynomial.py b/astropy/modeling/polynomial.py index eb2d1967c2f0..8de208589e89 100644 --- a/astropy/modeling/polynomial.py +++ b/astropy/modeling/polynomial.py @@ -13,7 +13,7 @@ from .functional_models import Shift from .parameters import Parameter from .utils import poly_map_domain, comb, check_broadcast -from ..utils import lazyproperty, indent +from ..utils import indent __all__ = [ @@ -40,7 +40,7 @@ class PolynomialBase(FittableModel): linear = True col_fit_deriv = False - @lazyproperty + @property def param_names(self): """Coefficient names generated based on the model's polynomial degree and number of dimensions. diff --git a/astropy/modeling/tests/test_core.py b/astropy/modeling/tests/test_core.py index 6018484f6ee9..41203156b358 100644 --- a/astropy/modeling/tests/test_core.py +++ b/astropy/modeling/tests/test_core.py @@ -9,11 +9,10 @@ import pytest import numpy as np from numpy.testing.utils import assert_allclose -from ..core import Model, InputParameterError, custom_model, render_model +from ..core import Model, InputParameterError, custom_model from ..parameters import Parameter from .. import models - class NonFittableModel(Model): """An example class directly subclassing Model for testing.""" @@ -208,43 +207,52 @@ def test_custom_inverse(): def test_render_model_2d(): - imshape = (71, 141) image = np.zeros(imshape) coords = y, x = np.indices(imshape) - model = models.Gaussian2D(x_stddev=6.1, y_stddev=3.9, theta=np.pi / 4) + model = models.Gaussian2D(x_stddev=6.1, y_stddev=3.9, theta=np.pi / 3) # test points for edges ye, xe = [0, 35, 70], [0, 70, 140] # test points for floating point positions yf, xf = [35.1, 35.5, 35.9], [70.1, 70.5, 70.9] - test_pts = [(a, b) for a in xe for b in ye] + [(a, b) for a in xf for b in yf] + test_pts = [(a, b) for a in xe for b in ye] + test_pts += [(a, b) for a in xf for b in yf] for x0, y0 in test_pts: model.x_mean = x0 model.y_mean = y0 expected = model(x, y) - for im in [image, None]: - for xy in [coords, None]: + for xy in [coords, None]: + for im in [image.copy(), None]: if (im is None) & (xy is None): # this case is tested in Fittable2DModelTester continue - actual = render_model(model, arr=image, coords=xy) + actual = model.render(out=im, coords=xy) + if im is None: + assert_allclose(actual, model.render(coords=xy)) # assert images match - assert_allclose(expected, actual, atol=2e-7) - # assert flux conserved - assert ((np.sum(expected) - np.sum(actual)) / np.sum(expected)) < 1e-7 + assert_allclose(expected, actual, atol=3e-7) + # assert model fully captured + if (x0, y0) == (70, 35): + boxed = model.render() + flux = np.sum(expected) + assert ((flux - np.sum(boxed)) / flux) < 1e-7 + # test an error is raised when the bounding box is larger than the input array + try: + actual = model.render(out=np.zeros((1, 1))) + except ValueError: + pass def test_render_model_1d(): - npix = 101 image = np.zeros(npix) coords = np.arange(npix) - model = models.Gaussian1D(stddev=49.5) + model = models.Gaussian1D() # test points test_pts = [0, 49.1, 49.5, 49.9, 100] @@ -252,20 +260,23 @@ def test_render_model_1d(): # test widths test_stdv = np.arange(5.5, 6.7, .2) - for x0, stdv in zip(test_pts, test_stdv): + for x0, stdv in [(p, s) for p in test_pts for s in test_stdv]: model.mean = x0 model.stddev = stdv expected = model(coords) - for im in [image, None]: - for x in [coords, None]: + for x in [coords, None]: + for im in [image.copy(), None]: if (im is None) & (x is None): # this case is tested in Fittable1DModelTester continue - actual = render_model(model, arr=image, coords=x) + actual = model.render(out=im, coords=x) # assert images match - assert_allclose(expected, actual, atol=2e-7) - # assert flux conserved - assert ((np.sum(expected) - np.sum(actual)) / np.sum(expected)) < 1e-7 + assert_allclose(expected, actual, atol=3e-7) + # assert model fully captured + if (x0, stdv) == (49.5, 5.5): + boxed = model.render() + flux = np.sum(expected) + assert ((flux - np.sum(boxed)) / flux) < 1e-7 def test_render_model_3d(): @@ -278,7 +289,11 @@ def ellipsoid(x, y, z, x0=13., y0=10., z0=8., a=4., b=3., c=2., amp=1.): val = (rsq < 1) * amp return val - Ellipsoid3D = models.custom_model(ellipsoid) + class Ellipsoid3D(custom_model(ellipsoid)): + def bounding_box_default(self): + return ((self.z0 - self.c, self.z0 + self.c), + (self.y0 - self.b, self.y0 + self.b), + (self.x0 - self.a, self.x0 + self.a)) model = Ellipsoid3D() @@ -290,17 +305,20 @@ def ellipsoid(x, y, z, x0=13., y0=10., z0=8., a=4., b=3., c=2., amp=1.): test_pts = [(x, y, z) for x in xe for y in ye for z in ze] test_pts += [(x, y, z) for x in xf for y in yf for z in zf] - for x0, y0, z0 in [(8,10,13)]:#test_pts: + for x0, y0, z0 in test_pts: model.x0 = x0 model.y0 = y0 model.z0 = z0 expected = model(*coords[::-1]) - for im in [image, None]: - for c in [coords, None]: + for c in [coords, None]: + for im in [image.copy(), None]: if (im is None) & (c is None): continue - actual = render_model(model, arr=image, coords=c) + actual = model.render(out=im, coords=c) + boxed = model.render() # assert images match assert_allclose(expected, actual) - # assert flux conserved - assert ((np.sum(expected) - np.sum(actual)) / np.sum(expected)) == 0 + # assert model fully captured + if (z0, y0, x0) == (8, 10, 13): + boxed = model.render() + assert (np.sum(expected) - np.sum(boxed)) == 0 diff --git a/astropy/modeling/tests/test_functional_models.py b/astropy/modeling/tests/test_functional_models.py index 9d2961c653aa..514b03e82c69 100644 --- a/astropy/modeling/tests/test_functional_models.py +++ b/astropy/modeling/tests/test_functional_models.py @@ -32,6 +32,7 @@ def test_GaussianAbsorption1D(): assert_allclose(g_ab(xx), 1 - g_em(xx)) assert_allclose(g_ab.fit_deriv(xx[0], 0.8, 3000, 20), -np.array(g_em.fit_deriv(xx[0], 0.8, 3000, 20))) + assert g_ab.bounding_box_default() == g_em.bounding_box def test_Gaussian2D(): @@ -117,6 +118,7 @@ def test_Ellipse2D(): e = em(x, y) assert np.all(e[e > 0] == amplitude) assert e[y0, x0] == amplitude + assert em.bounding_box_default() == em.bounding_box rotation = models.Rotation2D(angle=theta.degree) point1 = [2, 0] # Rotation2D center is (0, 0) diff --git a/astropy/modeling/tests/test_models.py b/astropy/modeling/tests/test_models.py index f8584f291852..ec5dd85a38f7 100644 --- a/astropy/modeling/tests/test_models.py +++ b/astropy/modeling/tests/test_models.py @@ -22,7 +22,7 @@ from .example_models import models_1D, models_2D from .. import (fitting, models, LabeledInput, SerialCompositeModel, SummedCompositeModel) -from ..core import FittableModel, render_model +from ..core import FittableModel from ..polynomial import PolynomialBase from ...tests.helper import pytest @@ -213,43 +213,36 @@ def SineModel(x, amplitude=4, frequency=1): def test_custom_model_bounding_box(): """Test bounding box evaluation for a 3D model""" - def ellipsoid(x, y, z, x0=13., y0=10., z0=8., a=4., b=3., c=2., amp=1.): + def ellipsoid(x, y, z, x0=13, y0=10, z0=8, a=4, b=3, c=2, amp=1): rsq = ((x - x0) / a) ** 2 + ((y - y0) / b) ** 2 + ((z - z0) / c) ** 2 val = (rsq < 1) * amp return val - def ellipsoid_bbox(self): - return ((self.z0 - self.c, self.z0 + self.c), - (self.y0 - self.b, self.y0 + self.b), - (self.x0 - self.a, self.x0 + self.a)) - - Ellipsoid3D = models.custom_model(ellipsoid) - Ellipsoid3D.bounding_box_default = ellipsoid_bbox + class Ellipsoid3D(models.custom_model(ellipsoid)): + def bounding_box_default(self): + return ((self.z0 - self.c, self.z0 + self.c), + (self.y0 - self.b, self.y0 + self.b), + (self.x0 - self.a, self.x0 + self.a)) model = Ellipsoid3D() - model.bounding_box = 'auto' bbox = model.bounding_box - if bbox is None: - pytest.skip("Bounding_box is not defined for model.") - # Check for exact agreement within bounded region - zlim, ylim, xlim = bbox - dx = np.ceil((xlim[1] - xlim[0]) / 2) - dy = np.ceil((ylim[1] - ylim[0]) / 2) - dz = np.ceil((zlim[1] - zlim[0]) / 2) - z0, y0, x0 = np.mean(bbox, axis=1).astype(int) - z, y, x = np.mgrid[z0 - dz: z0 + dz + 1, y0 - dy: - y0 + dy + 1, x0 - dx: x0 + dx + 1] + assert bbox == model.bounding_box_default() - expected = model(x, y, z) - actual = render_model(model) + zlim, ylim, xlim = bbox + dz, dy, dx = np.diff(bbox) / 2 + z1, y1, x1 = np.mgrid[slice(zlim[0], zlim[1] + 1), + slice(ylim[0], ylim[1] + 1), + slice(xlim[0], xlim[1] + 1)] + z2, y2, x2 = np.mgrid[slice(zlim[0] - dz, zlim[1] + dz + 1), + slice(ylim[0] - dy, ylim[1] + dy + 1), + slice(xlim[0] - dx, xlim[1] + dx + 1)] - utils.assert_allclose(actual, expected, rtol=0, atol=0) + arr = model(x2, y2, z2) + sub_arr = model(x1, y1, z1) - # check result with no bounding box defined - model.bounding_box = None - actual = render_model(model, coords=[z,y,x]) - utils.assert_allclose(actual, expected, rtol=0, atol=0) + # check for flux agreement + assert abs(arr.sum() - sub_arr.sum()) < arr.sum() * 1e-7 class Fittable2DModelTester(object): @@ -297,27 +290,36 @@ def test_bounding_box2D(self, model_class, test_parameters): model = create_model(model_class, test_parameters) - bbox = model.bounding_box - if bbox is None: + # testing setter + model.bounding_box = ((-5, 5), (-5, 5)) + model.bounding_box = None + model.bounding_box = 'auto' + + # test the exception of dimensions don't match + try: + model.bounding_box = (-5, 5) + except ValueError: + pass + + try : + bbox = model.bounding_box + except NotImplementedError: pytest.skip("Bounding_box is not defined for model.") - # Check for exact agreement within bounded region - xlim, ylim = bbox - dx = np.ceil((xlim[1] - xlim[0]) / 2) - dy = np.ceil((ylim[1] - ylim[0]) / 2) - x0, y0 = np.mean(bbox, axis=1).astype(int) - y, x = np.mgrid[y0 - dy: y0 + dy + 1, - x0 - dx: x0 + dx + 1] + assert bbox == model.bounding_box_default() - expected = model(x, y) - actual = render_model(model) + ylim, xlim = bbox + dy, dx = np.diff(bbox)/2 + y1, x1 = np.mgrid[slice(ylim[0], ylim[1] + 1), + slice(xlim[0], xlim[1] + 1)] + y2, x2 = np.mgrid[slice(ylim[0] - dy, ylim[1] + dy + 1), + slice(xlim[0] - dx, xlim[1] + dx + 1)] - utils.assert_allclose(actual, expected, rtol=0, atol=0) + arr = model(x2, y2) + sub_arr = model(x1, y1) - # check result with no bounding box defined - model.bounding_box = None - actual = render_model(model, coords=[y, x]) - utils.assert_allclose(actual, expected, rtol=0, atol=0) + # check for flux agreement + assert abs(arr.sum() - sub_arr.sum()) < arr.sum() * 1e-7 @pytest.mark.skipif('not HAS_SCIPY') def test_fitter2D(self, model_class, test_parameters): @@ -461,24 +463,33 @@ def test_bounding_box1D(self, model_class, test_parameters): model = create_model(model_class, test_parameters) - bbox = model.bounding_box - if bbox is None: + # testing setter + model.bounding_box = (-5, 5) + model.bounding_box = None + model.bounding_box = 'auto' + + # test exception if dimensions don't match + try: + model.bounding_box = 5 + except ValueError: + pass + + try: + bbox = model.bounding_box + except NotImplementedError: pytest.skip("Bounding_box is not defined for model.") - # Check for exact agreement within bounded region - dx = np.ceil(np.diff(model.bounding_box)[0] / 2) - x0 = int(np.mean(bbox)) - x = np.mgrid[x0 - dx: x0 + dx + 1] + assert bbox == model.bounding_box_default() - expected = model(x) - actual = render_model(model) + dx = np.diff(bbox) / 2 + x1 = np.mgrid[slice(bbox[0], bbox[1] + 1)] + x2 = np.mgrid[slice(bbox[0] - dx, bbox[1] + dx + 1)] - utils.assert_allclose(actual, expected, rtol=0, atol=0) + arr = model(x2) + sub_arr = model(x1) - # check result with no bounding box defined - model.bounding_box = None - actual = render_model(model, coords=x) - utils.assert_allclose(actual, expected, rtol=0, atol=0) + # check for flux agreement + assert abs(arr.sum() - sub_arr.sum()) < arr.sum() * 1e-7 @pytest.mark.skipif('not HAS_SCIPY') def test_fitter1D(self, model_class, test_parameters): diff --git a/astropy/modeling/tests/test_parameters.py b/astropy/modeling/tests/test_parameters.py index 8b56e272d6c0..abdbb6556167 100644 --- a/astropy/modeling/tests/test_parameters.py +++ b/astropy/modeling/tests/test_parameters.py @@ -100,29 +100,6 @@ def test_parameter_operators(): assert abs(par) == abs(num) -def test_parameter_name_attribute_mismatch(): - """ - It should not be possible to define Parameters on a model with different - names from the attributes they are assigned to. - """ - - def make_bad_class(): - class BadModel(Model): - foo = Parameter('bar') - - def __call__(self): pass - - def make_good_class(): - class GoodModel(Model): - # This is redundant but okay - foo = Parameter('foo') - - def __call__(self): pass - - make_good_class() - pytest.raises(ModelDefinitionError, make_bad_class) - - class TestParameters(object): def setup_class(self): diff --git a/astropy/modeling/tests/test_utils.py b/astropy/modeling/tests/test_utils.py index cf6e3955e8ef..d523003d8585 100644 --- a/astropy/modeling/tests/test_utils.py +++ b/astropy/modeling/tests/test_utils.py @@ -8,7 +8,6 @@ import numpy as np from ..utils import ExpressionTree as ET, ellipse_extent -from ..core import render_model from ..models import Ellipse2D @@ -91,7 +90,7 @@ def test_ellipse_extent(): model.bounding_box = limits - actual = render_model(model, coords=coords) + actual = model.render(coords=coords) expected = model(x, y) diff --git a/astropy/nddata/tests/test_utils.py b/astropy/nddata/tests/test_utils.py index b54b34bc2066..44e7d0d9ee8c 100644 --- a/astropy/nddata/tests/test_utils.py +++ b/astropy/nddata/tests/test_utils.py @@ -10,6 +10,7 @@ Cutout2D) from ...wcs import WCS from ...coordinates import SkyCoord +from ... import units as u try: import skimage @@ -336,23 +337,42 @@ def setup_class(self): self.wcs = wcs def test_cutout(self): - position = (2.1, 1.9) - shape = (3, 3) - c = Cutout2D(self.data, position, shape) - assert c.data.shape == shape - assert c.data[1, 1] == 10 - assert c.origin_original == (1, 1) - assert c.origin_cutout == (0, 0) - assert c.input_position_original == position - assert_allclose(c.input_position_cutout, (1.1, 0.9)) - assert c.position_original == (2., 2.) - assert c.position_cutout == (1., 1.) - assert c.center_original == (2., 2.) - assert c.center_cutout == (1., 1.) - assert c.bbox_original == ((1, 3), (1, 3)) - assert c.bbox_cutout == ((0, 2), (0, 2)) - assert c.slices_original == (slice(1, 4), slice(1, 4)) - assert c.slices_cutout == (slice(0, 3), slice(0, 3)) + for shape in [(3, 3), (3*u.pixel, 3*u.pix)]: + position = (2.1, 1.9) + c = Cutout2D(self.data, position, shape) + assert c.data.shape == (3, 3) + assert c.data[1, 1] == 10 + assert c.origin_original == (1, 1) + assert c.origin_cutout == (0, 0) + assert c.input_position_original == position + assert_allclose(c.input_position_cutout, (1.1, 0.9)) + assert c.position_original == (2., 2.) + assert c.position_cutout == (1., 1.) + assert c.center_original == (2., 2.) + assert c.center_cutout == (1., 1.) + assert c.bbox_original == ((1, 3), (1, 3)) + assert c.bbox_cutout == ((0, 2), (0, 2)) + assert c.slices_original == (slice(1, 4), slice(1, 4)) + assert c.slices_cutout == (slice(0, 3), slice(0, 3)) + + def test_cutout_sidelength(self): + for side_length in [3, 3*u.pixel]: + position = (2.1, 1.9) + c = Cutout2D(self.data, position, side_length=side_length) + assert c.data.shape == (3, 3) + assert c.data[1, 1] == 10 + assert c.origin_original == (1, 1) + assert c.origin_cutout == (0, 0) + assert c.input_position_original == position + assert_allclose(c.input_position_cutout, (1.1, 0.9)) + assert c.position_original == (2., 2.) + assert c.position_cutout == (1., 1.) + assert c.center_original == (2., 2.) + assert c.center_cutout == (1., 1.) + assert c.bbox_original == ((1, 3), (1, 3)) + assert c.bbox_cutout == ((0, 2), (0, 2)) + assert c.slices_original == (slice(1, 4), slice(1, 4)) + assert c.slices_cutout == (slice(0, 3), slice(0, 3)) def test_cutout_trim_overlap(self): shape = (3, 3) diff --git a/astropy/nddata/utils.py b/astropy/nddata/utils.py index ac39e9f8d438..39c4d4e42cb2 100644 --- a/astropy/nddata/utils.py +++ b/astropy/nddata/utils.py @@ -10,6 +10,7 @@ from astropy.utils import lazyproperty from astropy.coordinates import SkyCoord from astropy.wcs.utils import skycoord_to_pixel +from astropy import units as u __all__ = ['extract_array', 'add_array', 'subpixel_indices', @@ -474,8 +475,8 @@ def block_replicate(data, block_size, conserve_sum=True): class Cutout2D(object): """Create a cutout object from a 2D array.""" - def __init__(self, data, position, shape, wcs=None, mode='trim', - fill_value=np.nan, copy=False): + def __init__(self, data, position, shape=None, side_length=None, wcs=None, + mode='trim', fill_value=np.nan, copy=False): """ The returned object will contain a 2D cutout array. If ``copy=False`` (default), the cutout array is a view into the @@ -486,6 +487,9 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', object will also contain a copy of the original WCS, but updated for the cutout array. + The shape of the cutout is determined by the ``shape`` parameter, or + ``side_length`` for a square cutout. + For example usage, see :ref:`cutout_images`. .. warning:: @@ -505,10 +509,17 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', `~astropy.coordinates.SkyCoord`, in which case ``wcs`` is a required input. - shape : tuple + shape : tuple, optional The shape (``(ny, nx)``) of the cutout array in pixel coordinates (but see the ``mode`` keyword for additional - details). + details). May be specified as a `~astropy.units.Quantity` + equivalent to pixels. + + side_length : scalar, optional + The length (in pixel coordinates) of a side in a square cutout + array. ``shape`` will be set to ``(side_length, side_length)``. + See the ``mode`` keyword for additional details. May be specified + as a `~astropy.units.Quantity` equivalent to pixels. wcs : `~astropy.wcs.WCS`, optional A WCS object associated with the input ``data`` array. If @@ -551,6 +562,7 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', -------- >>> import numpy as np >>> from astropy.nddata.utils import Cutout2D + >>> from astropy import units as u >>> data = np.arange(20.).reshape(5, 4) >>> c1 = Cutout2D(data, (2, 2), (3, 3)) >>> print(c1.data) @@ -565,16 +577,21 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', >>> print(c1.origin_original) (1, 1) - >>> c2 = Cutout2D(data, (0, 0), (3, 3)) + >>> c2 = Cutout2D(data, (0, 0), shape=(3*u.pixel, 3*u.pixel)) >>> print(c2.data) [[ 0. 1.] [ 4. 5.]] - >>> c3 = Cutout2D(data, (0, 0), (3, 3), mode='partial') + >>> c3 = Cutout2D(data, (0, 0), shape=(3, 3), mode='partial') >>> print(c3.data) [[ nan nan nan] [ nan 0. 1.] [ nan 4. 5.]] + + >>> c4 = Cutout2D(data, (0, 0), side_length=3) + >>> print(c4.data) + [[ 0. 1.] + [ 4. 5.]] """ if isinstance(position, SkyCoord): @@ -583,11 +600,22 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', 'SkyCoord') position = skycoord_to_pixel(position, wcs, mode='all') # (x, y) + if side_length is None and shape is None: + raise ValueError("Either side_length or shape must be specified") + + if side_length is not None and shape is not None: + raise ValueError("Cannot specify both side_length and shape") + + if side_length is not None: + shape = (side_length, side_length) + + shape = [x.value if u.pixel.is_equivalent(x) else x for x in shape] + # extract_array and overlap_slices use (y, x) positions pos = position[::-1] data = np.asanyarray(data) - cutout_data, input_position_cutout = extract_array( + cutout_data, input_position_cutout = extract_array( data, shape, pos, mode=mode, fill_value=fill_value, return_position=True) if copy: @@ -595,8 +623,9 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', self.data = cutout_data self.input_position_cutout = input_position_cutout[::-1] # (x, y) - slices_original, slices_cutout = overlap_slices(data.shape, shape, - pos, mode=mode) + slices_original, slices_cutout = overlap_slices( + data.shape, shape, pos, mode=mode) + self.slices_original = slices_original self.slices_cutout = slices_cutout @@ -604,18 +633,17 @@ def __init__(self, data, position, shape, wcs=None, mode='trim', self.input_position_original = position self.shape_input = shape - ((self.xmin_original, self.xmax_original), - (self.ymin_original, self.ymax_original)) = self.bbox_original + ((self.ymin_original, self.ymax_original), + (self.xmin_original, self.xmax_original)) = self.bbox_original - ((self.xmin_cutout, self.xmax_cutout), - (self.ymin_cutout, self.ymax_cutout)) = self.bbox_cutout + ((self.ymin_cutout, self.ymax_cutout), + (self.xmin_cutout, self.xmax_cutout)) = self.bbox_cutout # the true origin pixel of the cutout array, including any # filled cutout values - self._origin_original_true = (self.origin_original[0] - - self.slices_cutout[1].start, - self.origin_original[1] - - self.slices_cutout[0].start) + self._origin_original_true = ( + self.origin_original[0] - self.slices_cutout[1].start, + self.origin_original[1] - self.slices_cutout[0].start) if wcs is not None: self.wcs = deepcopy(wcs) @@ -715,14 +743,14 @@ def _calc_center(slices): @staticmethod def _calc_bbox(slices): """ - Calculate a minimal bounding box in the form ``((xmin, xmax), - (ymin, ymax))``. Note these are pixel locations, not slice + Calculate a minimal bounding box in the form ``((ymin, ymax), + (xmin, xmax))``. Note these are pixel locations, not slice indices. For ``mode='partial'``, the bounding box indices are for the valid (non-filled) cutout values. """ # (stop - 1) to return the max pixel location, not the slice index - return ((slices[1].start, slices[1].stop - 1), - (slices[0].start, slices[0].stop - 1)) + return ((slices[0].start, slices[0].stop - 1), + (slices[1].start, slices[1].stop - 1)) @lazyproperty def origin_original(self): @@ -782,7 +810,7 @@ def center_cutout(self): @lazyproperty def bbox_original(self): """ - The bounding box ``(ymin, xmin, ymax, xmax)`` of the minimal + The bounding box ``((ymin, ymax), (xmin, xmax))`` of the minimal rectangular region of the cutout array with respect to the original array. For ``mode='partial'``, the bounding box indices are for the valid (non-filled) cutout values. @@ -792,7 +820,7 @@ def bbox_original(self): @lazyproperty def bbox_cutout(self): """ - The bounding box ``(ymin, xmin, ymax, xmax)`` of the minimal + The bounding box ``((ymin, ymax), (xmin, xmax))`` of the minimal rectangular region of the cutout array with respect to the cutout array. For ``mode='partial'``, the bounding box indices are for the valid (non-filled) cutout values. diff --git a/astropy/stats/sigma_clipping.py b/astropy/stats/sigma_clipping.py index 80b7d46ea420..808de07dc7c2 100644 --- a/astropy/stats/sigma_clipping.py +++ b/astropy/stats/sigma_clipping.py @@ -38,8 +38,9 @@ def _sigma_clip(data, sigma=3, sigma_lower=None, sigma_upper=None, iters=5, Perform sigma-clipping on the provided data. The data will be iterated over, each time rejecting points that are - discrepant by more than a specified number of standard deviations - from a center value. + discrepant by more than a specified number of standard deviations from a + center value. If the data contains invalid values (NaNs or infs), + they are automatically masked before performing the sigma clipping. .. note:: `scipy.stats.sigmaclip @@ -174,6 +175,11 @@ def _sigma_clip(data, sigma=3, sigma_lower=None, sigma_upper=None, iters=5, stdfunc = lambda d: np.expand_dims(stdfunc_in(d, axis=axis), axis=axis) + if np.any(~np.isfinite(data)): + data = np.ma.masked_invalid(data) + warnings.warn("Input data contains invalid values (NaNs or infs), " + "which were automatically masked.", AstropyUserWarning) + filtered_data = np.ma.array(data, copy=copy) if iters is None: diff --git a/astropy/stats/tests/test_sigma_clipping.py b/astropy/stats/tests/test_sigma_clipping.py index 46d5a1f5bca0..3ff680aac0cc 100644 --- a/astropy/stats/tests/test_sigma_clipping.py +++ b/astropy/stats/tests/test_sigma_clipping.py @@ -104,3 +104,22 @@ def test_sigma_clipped_stats(): assert result2[0] == 1. assert result2[1] == 1. assert result2[2] == 0. + + +def test_invalid_sigma_clip(): + """Test sigma_clip of data containing invalid values.""" + + data = np.ones((5, 5)) + data[2, 2] = 1000 + data[3, 4] = np.nan + data[1, 1] = np.inf + + result = sigma_clip(data) + + # Pre #4051 if data contains any NaN or infs sigma_clip returns the mask + # containig `False` only or TypeError if data also contains a masked value. + + assert result.mask[2, 2] == True + assert result.mask[3, 4] == True + assert result.mask[1, 1] == True + diff --git a/astropy/tests/coveragerc b/astropy/tests/coveragerc index 807c9926578e..e3aadab668e4 100644 --- a/astropy/tests/coveragerc +++ b/astropy/tests/coveragerc @@ -10,6 +10,7 @@ omit = astropy/utils/compat/* astropy/version* astropy/wcs/docstrings* + astropy/_erfa/erfa_generator.py [report] exclude_lines = diff --git a/astropy/tests/pytest_plugins.py b/astropy/tests/pytest_plugins.py index 0fa02492d8ad..faf99edfb12d 100644 --- a/astropy/tests/pytest_plugins.py +++ b/astropy/tests/pytest_plugins.py @@ -30,6 +30,7 @@ from .output_checker import AstropyOutputChecker, FIX, FLOAT_CMP from ..utils import OrderedDict from ..utils.argparse import writeable_directory +from ..utils.introspection import resolve_name # Needed for Python 2.6 compatibility try: @@ -581,7 +582,7 @@ def pytest_report_header(config): for module_display, module_name in six.iteritems(PYTEST_HEADER_MODULES): try: - module = __import__(module_name) + module = resolve_name(module_name) except ImportError: s += "{0}: not available\n".format(module_display) else: diff --git a/astropy/utils/console.py b/astropy/utils/console.py index cf7fbc7ce00a..1ad8a74b868b 100644 --- a/astropy/utils/console.py +++ b/astropy/utils/console.py @@ -32,12 +32,18 @@ IPythonIOStream = None else: try: - from IPython.zmq.iostream import OutStream + from ipykernel.iostream import OutStream except ImportError: try: - from IPython.kernel.zmq.iostream import OutStream + from IPython.zmq.iostream import OutStream except ImportError: - OutStream = None + try: + from IPython.kernel.zmq.iostream import OutStream + except ImportError: + OutStream = None + + from IPython import version_info + ipython_major_version = version_info[0] if OutStream is not None: from IPython.utils import io as ipyio @@ -490,7 +496,10 @@ def __init__(self, total_or_items, ipython_widget=False, file=None): if ipython_widget: # Import only if ipython_widget, i.e., widget in IPython # notebook - from IPython.html import widgets + if ipython_major_version < 4: + from IPython.html import widgets + else: + from ipywidgets import widgets from IPython.display import display if file is None: @@ -631,10 +640,14 @@ def _update_ipython_widget(self, value=None): # if none exists. if not hasattr(self, '_widget'): # Import only if an IPython widget, i.e., widget in iPython NB - from IPython.html import widgets + if ipython_major_version < 4: + from IPython.html import widgets + self._widget = widgets.FloatProgressWidget() + else: + from ipywidgets import widgets + self._widget = widgets.FloatProgress() from IPython.display import display - self._widget = widgets.FloatProgressWidget() display(self._widget) self._widget.value = 0 diff --git a/astropy/utils/data.py b/astropy/utils/data.py index 0507485c78d4..4e102593dd0f 100644 --- a/astropy/utils/data.py +++ b/astropy/utils/data.py @@ -26,6 +26,7 @@ from .. import config as _config from ..utils.exceptions import AstropyWarning +from ..utils.introspection import resolve_name __all__ = [ @@ -123,8 +124,9 @@ def get_readable_fileobj(name_or_obj, encoding=None, cache=False, Given a filename or a readable file-like object, return a context manager that yields a readable file-like object. - This supports passing filenames, URLs, and readable file-like - objects, any of which can be compressed in gzip or bzip2. + This supports passing filenames, URLs, and readable file-like objects, + any of which can be compressed in gzip, bzip2 or lzma (xz) if the + appropriate compression libraries are provided by the Python installation. Notes ----- @@ -232,28 +234,63 @@ def get_readable_fileobj(name_or_obj, encoding=None, cache=False, try: import bz2 except ImportError: + for fd in close_fds: + fd.close() raise ValueError( ".bz2 format files are not supported since the Python " "interpreter does not include the bz2 module") try: # bz2.BZ2File does not support file objects, only filenames, so we # need to write the data to a temporary file - tmp = NamedTemporaryFile("wb", delete=False) - tmp.write(fileobj.read()) - tmp.close() - delete_fds.append(tmp) - fileobj_new = bz2.BZ2File(tmp.name, mode='rb') + with NamedTemporaryFile("wb", delete=False) as tmp: + tmp.write(fileobj.read()) + tmp.close() + fileobj_new = bz2.BZ2File(tmp.name, mode='rb') fileobj_new.read(1) # need to check that the file is really bzip2 except IOError: # invalid bzip2 file fileobj.seek(0) fileobj_new.close() + # raise else: fileobj_new.seek(0) close_fds.append(fileobj_new) fileobj = fileobj_new + elif signature[:3] == b'\xfd7z': # xz + try: + # for Python < 3.3 try backports.lzma; pyliblzma installs as lzma, + # but does not support TextIOWrapper + if sys.version_info >= (3,3,0): + import lzma + fileobj_new = lzma.LZMAFile(fileobj, mode='rb') + else: + from backports import lzma + from backports.lzma import LZMAFile + # when called with file object, returns a non-seekable instance + # need a filename here, too, so have to write the data to a + # temporary file + with NamedTemporaryFile("wb", delete=False) as tmp: + tmp.write(fileobj.read()) + tmp.close() + fileobj_new = LZMAFile(tmp.name, mode='rb') + fileobj_new.read(1) # need to check that the file is really xz + except ImportError: + for fd in close_fds: + fd.close() + raise ValueError( + ".xz format files are not supported since the Python " + "interpreter does not include the lzma module. " + "On Python versions < 3.3 consider installing backports.lzma") + except (IOError, EOFError) as e: # invalid xz file + fileobj.seek(0) + fileobj_new.close() + # should we propagate this to the caller to signal bad content? + # raise ValueError(e) + else: + fileobj_new.seek(0) + fileobj = fileobj_new - # By this point, we have a file, io.FileIO, gzip.GzipFile, or - # bz2.BZ2File instance opened in binary mode (that is, read + # By this point, we have a file, io.FileIO, gzip.GzipFile, bz2.BZ2File + # or lzma.LZMAFile instance opened in binary mode (that is, read # returns bytes). Now we need to, if requested, wrap it in a # io.TextIOWrapper so read will return unicode based on the # encoding parameter. @@ -791,7 +828,7 @@ def _find_pkg_data_path(data_name): rootpkgname = pkgname.partition('.')[0] - rootpkg = __import__(rootpkgname) + rootpkg = resolve_name(rootpkgname) module_path = os.path.dirname(module.__file__) path = os.path.join(module_path, data_name) diff --git a/astropy/utils/decorators.py b/astropy/utils/decorators.py index 59f7c1b2230a..693da64e9d81 100644 --- a/astropy/utils/decorators.py +++ b/astropy/utils/decorators.py @@ -467,7 +467,7 @@ def fget(obj): return fget -class lazyproperty(object): +class lazyproperty(property): """ Works similarly to property(), but computes the value only once. @@ -501,20 +501,14 @@ class lazyproperty(object): """ def __init__(self, fget, fset=None, fdel=None, doc=None): - self._fget = fget - self._fset = fset - self._fdel = fdel - if doc is None: - self.__doc__ = fget.__doc__ - else: - self.__doc__ = doc - self._key = self._fget.__name__ + super(lazyproperty, self).__init__(fget, fset, fdel, doc) + self._key = self.fget.__name__ def __get__(self, obj, owner=None): try: return obj.__dict__[self._key] except KeyError: - val = self._fget(obj) + val = self.fget(obj) obj.__dict__[self._key] = val return val except AttributeError: @@ -524,8 +518,8 @@ def __get__(self, obj, owner=None): def __set__(self, obj, val): obj_dict = obj.__dict__ - if self._fset: - ret = self._fset(obj, val) + if self.fset: + ret = self.fset(obj, val) if ret is not None and obj_dict.get(self._key) is ret: # By returning the value set the setter signals that it took # over setting the value in obj.__dict__; this mechanism allows @@ -534,33 +528,11 @@ def __set__(self, obj, val): obj_dict[self._key] = val def __delete__(self, obj): - if self._fdel: - self._fdel(obj) + if self.fdel: + self.fdel(obj) if self._key in obj.__dict__: del obj.__dict__[self._key] - def getter(self, fget): - return self.__ter(fget, 0) - - def setter(self, fset): - return self.__ter(fset, 1) - - def deleter(self, fdel): - return self.__ter(fdel, 2) - - def __ter(self, f, arg): - args = [self._fget, self._fset, self._fdel, self.__doc__] - args[arg] = f - cls_ns = sys._getframe(1).f_locals - for k, v in six.iteritems(cls_ns): - if v is self: - property_name = k - break - - cls_ns[property_name] = lazyproperty(*args) - - return cls_ns[property_name] - class sharedmethod(classmethod): """ diff --git a/astropy/utils/iers/iers.py b/astropy/utils/iers/iers.py index 80bfe2ac4a1c..3afbe5cdfa69 100644 --- a/astropy/utils/iers/iers.py +++ b/astropy/utils/iers/iers.py @@ -48,17 +48,20 @@ `iers.IERS_A_URL` and `iers.IERS_B_URL`:: >>> from astropy.utils.iers import IERS_A, IERS_A_URL - >>> from astropy.utils.data import download_file - >>> iers_a_file = download_file(IERS_A_URL, cache=True) # doctest: +SKIP - >>> iers_a = IERS_A.open(iers_a_file) # doctest: +SKIP + >>> iers_a = IERS_A.open(IERS_A_URL) # doctest: +SKIP """ from __future__ import (absolute_import, division, print_function, unicode_literals) +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + import numpy as np from ...table import Table, QTable -from ...utils.data import get_pkg_data_filename +from ...utils.data import get_pkg_data_filename, download_file __all__ = ['IERS', 'IERS_B', 'IERS_A', 'FROM_IERS_B', 'FROM_IERS_A', 'FROM_IERS_A_PREDICTION', @@ -95,16 +98,19 @@ class IERS(QTable): iers_table = None @classmethod - def open(cls, file=None, **kwargs): + def open(cls, file=None, cache=False, **kwargs): """Open an IERS table, reading it from a file if not loaded before. Parameters ---------- file : str or None - full path to the ascii file holding IERS data, for passing on to - the `read` class methods (further optional arguments that are - available for some IERS subclasses can be added). + full local or network path to the ascii file holding IERS data, + for passing on to the `read` class methods (further optional + arguments that are available for some IERS subclasses can be added). If None, use the default location from the `read` class method. + cache : bool + Whether to use cache. Defaults to False, since IERS files + are regularly updated. Returns ------- @@ -117,13 +123,19 @@ def open(cls, file=None, **kwargs): table if `file=None` (the default). If a table needs to be re-read from disk, pass on an explicit file - loction or use the (sub-class) close method and re-open. + location or use the (sub-class) close method and re-open. + + If the location is a network location it is first downloaded via + download_file. For the IERS class itself, an IERS_B sub-class instance is opened. """ if file is not None or cls.iers_table is None: if file is not None: - kwargs.update(file=file) + if urlparse(file).netloc: + kwargs.update(file=download_file(file, cache=cache)) + else: + kwargs.update(file=file) cls.iers_table = cls.read(**kwargs) return cls.iers_table diff --git a/astropy/utils/iers/tests/test_iers.py b/astropy/utils/iers/tests/test_iers.py index 6035252bb5cd..e6a70797fcd3 100644 --- a/astropy/utils/iers/tests/test_iers.py +++ b/astropy/utils/iers/tests/test_iers.py @@ -10,6 +10,7 @@ from .... import units as u from ....table import QTable from ....time import Time +from ....extern.six.moves import urllib FILE_NOT_FOUND_ERROR = getattr(__builtins__, 'FileNotFoundError', IOError) @@ -68,6 +69,13 @@ def test_open_filename(self): with pytest.raises(FILE_NOT_FOUND_ERROR): iers.IERS.open('surely this does not exist') + def test_open_network_url(self): + iers.IERS_A.close() + iers.IERS_A.open("file:" + urllib.request.pathname2url(IERS_A_EXCERPT)) + assert iers.IERS_A.iers_table is not None + assert isinstance(iers.IERS_A.iers_table, QTable) + iers.IERS_A.close() + class TestIERS_AExcerpt(): def test_simple(self): iers_tab = iers.IERS_A.open(IERS_A_EXCERPT) diff --git a/astropy/utils/introspection.py b/astropy/utils/introspection.py index f51376e6aa77..d778d0d7d1ef 100644 --- a/astropy/utils/introspection.py +++ b/astropy/utils/introspection.py @@ -21,7 +21,7 @@ __doctest_skip__ = ['find_current_module'] -def resolve_name(name): +def resolve_name(name, *additional_parts): """Resolve a name like ``module.object`` to an object and return it. This ends up working like ``from module import object`` but is easier @@ -37,11 +37,17 @@ def resolve_name(name): including parent modules, separated by dots. Also known as the fully qualified name of the object. + additional_parts : iterable, optional + If more than one positional arguments are given, those arguments are + automatically dotted together with ``name``. + Examples -------- >>> resolve_name('astropy.utils.introspection.resolve_name') + >>> resolve_name('astropy', 'utils', 'introspection', 'resolve_name') + Raises ------ @@ -49,29 +55,34 @@ def resolve_name(name): If the module or named object is not found. """ + additional_parts = '.'.join(additional_parts) + + if additional_parts: + name = name + '.' + additional_parts + # Note: On python 2 these must be str objects and not unicode parts = [str(part) for part in name.split('.')] if len(parts) == 1: # No dots in the name--just a straight up module import cursor = 1 - attr_name = str('') # Must not be unicode on Python 2 + fromlist=[] else: cursor = len(parts) - 1 - attr_name = parts[-1] + fromlist = [parts[-1]] module_name = parts[:cursor] while cursor > 0: try: - ret = __import__(str('.'.join(module_name)), fromlist=[attr_name]) + ret = __import__(str('.'.join(module_name)), fromlist=fromlist) break except ImportError: if cursor == 0: raise cursor -= 1 module_name = parts[:cursor] - attr_name = parts[cursor] + fromlist = [parts[cursor]] ret = '' for part in parts[cursor:]: @@ -137,7 +148,7 @@ def minversion(module, version, inclusive=True, version_path='__version__'): if '.' not in version_path: have_version = getattr(module, version_path) else: - have_version = resolve_name('.'.join([module.__name__, version_path])) + have_version = resolve_name(module.__name__, version_path) try: from pkg_resources import parse_version @@ -295,8 +306,7 @@ def find_mod_objs(modname, onlylocals=False): """ - __import__(modname) - mod = sys.modules[modname] + mod = resolve_name(modname) if hasattr(mod, '__all__'): pkgitems = [(k, mod.__dict__[k]) for k in mod.__all__] diff --git a/astropy/utils/misc.py b/astropy/utils/misc.py index eb09df19db5d..7b73533d006e 100644 --- a/astropy/utils/misc.py +++ b/astropy/utils/misc.py @@ -9,6 +9,7 @@ unicode_literals) +import abc import contextlib import difflib import inspect @@ -20,13 +21,17 @@ import traceback import unicodedata +from collections import defaultdict + from ..extern import six from ..extern.six.moves import urllib +from ..utils.compat.odict import OrderedDict __all__ = ['isiterable', 'silence', 'format_exception', 'NumpyRNGContext', 'find_api_page', 'is_path_hidden', 'walk_skip_hidden', - 'JsonCustomEncoder', 'indent', 'InheritDocstrings'] + 'JsonCustomEncoder', 'indent', 'InheritDocstrings', + 'OrderedDescriptor', 'OrderedDescriptorContainer'] def isiterable(obj): @@ -510,3 +515,289 @@ def is_public_member(key): break super(InheritDocstrings, cls).__init__(name, bases, dct) + + +@six.add_metaclass(abc.ABCMeta) +class OrderedDescriptor(object): + """ + Base class for descriptors whose order in the class body should be + preserved. Intended for use in concert with the + `OrderedDescriptorContainer` metaclass. + + Subclasses of `OrderedDescriptor` must define a value for a class attribute + called ``_class_attribute_``. This is the name of a class attribute on the + *container* class for these descriptors, which will be set to an + `~collections.OrderedDict` at class creation time. This + `~collections.OrderedDict` will contain a mapping of all class attributes + that were assigned instances of the `OrderedDescriptor` subclass, to the + instances themselves. See the documentation for + `OrderedDescriptorContainer` for a concrete example. + + Optionally, subclasses of `OrderedDescriptor` may define a value for a + class attribute called ``_name_attribute_``. This should be the name of + an attribute on instances of the subclass. When specified, during + creation of a class containing these descriptors, the name attribute on + each instance will be set to the name of the class attribute it was + assigned to on the class. + + .. note:: + + Although this class is intended for use with *descriptors* (i.e. + classes that define any of the ``__get__``, ``__set__``, or + ``__delete__`` magic methods), this base class is not itself a + descriptor, and technically this could be used for classes that are + not descriptors too. However, use with descriptors is the original + intended purpose. + """ + + # This id increments for each OrderedDescriptor instance created, so they + # are always ordered in the order they were created. Class bodies are + # guaranteed to be executed from top to bottom. Not sure if this is + # thread-safe though. + _nextid = 1 + + _class_attribute_ = abc.abstractproperty() + """ + Subclasses should define this attribute to the name of an attribute on + classes containing this subclass. That attribute will contain the mapping + of all instances of that `OrderedDescriptor` subclass defined in the class + body. If the same descriptor needs to be used with different classes, + each with different names of this attribute, multiple subclasses will be + needed. + """ + + _name_attribute_ = None + """ + Subclasses may optionally define this attribute to specify the name of an + attribute on instances of the class that should be filled with the + instance's attribute name at class creation time. + """ + + def __init__(self, *args, **kwargs): + # The _nextid attribute is shared across all subclasses so that + # different subclasses of OrderedDescriptors can be sorted correctly + # between themselves + self.__order = OrderedDescriptor._nextid + OrderedDescriptor._nextid += 1 + super(OrderedDescriptor, self).__init__() + + def __lt__(self, other): + """ + Defined for convenient sorting of `OrderedDescriptor` instances, which + are defined to sort in their creation order. + """ + + if (isinstance(self, OrderedDescriptor) and + isinstance(other, OrderedDescriptor)): + try: + return self.__order < other.__order + except AttributeError: + raise RuntimeError( + 'Could not determine ordering for {0} and {1}; at least ' + 'one of them is not calling super().__init__ in its ' + '__init__.'.format(self, other)) + else: + return NotImplemented + + +class OrderedDescriptorContainer(type): + """ + Classes should use this metaclass if they wish to use `OrderedDescriptor` + attributes, which are class attributes that "remember" the order in which + they were defined in the class body. + + Every subclass of `OrderedDescriptor` has an attribute called + ``_class_attribute_``. For example, if we have + + .. code:: python + + class ExampleDecorator(OrderedDescriptor): + _class_attribute_ = '_examples_' + + Then when a class with the `OrderedDescriptorContainer` metaclass is + created, it will automatically be assigned a class attribute ``_examples_`` + referencing an `~collections.OrderedDict` containing all instances of + ``ExampleDecorator`` defined in the class body, mapped to by the names of + the attributes they were assigned to. + + When subclassing a class with this metaclass, the descriptor dict (i.e. + ``_examples_`` in the above example) will *not* contain descriptors + inherited from the base class. That is, this only works by default with + decorators explicitly defined in the class body. However, the subclass + *may* define an attribute ``_inherit_decorators_`` which lists + `OrderedDescriptor` classes that *should* be added from base classes. + See the examples section below for an example of this. + + Examples + -------- + + >>> from astropy.extern import six + >>> from astropy.utils import OrderedDescriptor, OrderedDescriptorContainer + >>> class TypedAttribute(OrderedDescriptor): + ... \"\"\" + ... Attributes that may only be assigned objects of a specific type, + ... or subclasses thereof. For some reason we care about their order. + ... \"\"\" + ... + ... _class_attribute_ = 'typed_attributes' + ... _name_attribute_ = 'name' + ... # A default name so that instances not attached to a class can + ... # still be repr'd; useful for debugging + ... name = '' + ... + ... def __init__(self, type): + ... # Make sure not to forget to call the super __init__ + ... super(TypedAttribute, self).__init__() + ... self.type = type + ... + ... def __get__(self, obj, objtype=None): + ... if obj is None: + ... return self + ... if self.name in obj.__dict__: + ... return obj.__dict__[self.name] + ... else: + ... raise AttributeError(self.name) + ... + ... def __set__(self, obj, value): + ... if not isinstance(value, self.type): + ... raise ValueError('{0}.{1} must be of type {2!r}'.format( + ... obj.__class__.__name__, self.name, self.type)) + ... obj.__dict__[self.name] = value + ... + ... def __delete__(self, obj): + ... if self.name in obj.__dict__: + ... del obj.__dict__[self.name] + ... else: + ... raise AttributeError(self.name) + ... + ... def __repr__(self): + ... if isinstance(self.type, tuple) and len(self.type) > 1: + ... typestr = '({0})'.format( + ... ', '.join(t.__name__ for t in self.type)) + ... else: + ... typestr = self.type.__name__ + ... return '<{0}(name={1}, type={2})>'.format( + ... self.__class__.__name__, self.name, typestr) + ... + + Now let's create an example class that uses this ``TypedAttribute``:: + + >>> @six.add_metaclass(OrderedDescriptorContainer) + ... class Point2D(object): + ... x = TypedAttribute((float, int)) + ... y = TypedAttribute((float, int)) + ... + ... def __init__(self, x, y): + ... self.x, self.y = x, y + ... + >>> p1 = Point2D(1.0, 2.0) + >>> p1.x + 1.0 + >>> p1.y + 2.0 + >>> p2 = Point2D('a', 'b') # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Point2D.x must be of type (float, int>) + + We see that ``TypedAttribute`` works more or less as advertised, but + there's nothing special about that. Let's see what + `OrderedDescriptorContainer` did for us:: + + >>> Point2D.typed_attributes + OrderedDict([('x', ), + ('y', )]) + + If we create a subclass, it does *not* by default add inherited descriptors + to ``typed_attributes``:: + + >>> class Point3D(Point2D): + ... z = TypedAttribute((float, int)) + ... + >>> Point3D.typed_attributes + OrderedDict([('z', )]) + + However, if we specify ``_inherit_descriptors_`` from ``Point2D`` then + it will do so:: + + >>> class Point3D(Point2D): + ... _inherit_descriptors_ = (TypedAttribute,) + ... z = TypedAttribute((float, int)) + ... + >>> Point3D.typed_attributes + OrderedDict([('x', ), + ('y', ), + ('z', )]) + + .. note:: + + Hopefully it is clear from these examples that this construction + also allows a class of type `OrderedDescriptorContainer` to use + multiple different `OrderedDescriptor` classes simultaneously. + """ + + _inherit_descriptors_ = () + + def __init__(cls, cls_name, bases, members): + descriptors = defaultdict(list) + seen = set() + inherit_descriptors = () + descr_bases = {} + + for mro_cls in cls.__mro__: + for name, obj in mro_cls.__dict__.items(): + if name in seen: + # Checks if we've already seen an attribute of the given + # name (if so it will override anything of the same name in + # any base class) + continue + + seen.add(name) + + if (not isinstance(obj, OrderedDescriptor) or + (inherit_descriptors and + not isinstance(obj, inherit_descriptors))): + # The second condition applies when checking any + # subclasses, to see if we can inherit any descriptors of + # the given type from subclasses (by default inheritance is + # disabled unless the class has _inherit_descriptors_ + # defined) + continue + + if obj._name_attribute_ is not None: + setattr(obj, obj._name_attribute_, name) + + # Don't just use the descriptor's class directly; instead go + # through its MRO and find the class on which _class_attribute_ + # is defined directly. This way subclasses of some + # OrderedDescriptor *may* override _class_attribute_ and have + # its own _class_attribute_, but by default all subclasses of + # some OrderedDescriptor are still grouped together + # TODO: It might be worth clarifying this in the docs + if obj.__class__ not in descr_bases: + for obj_cls_base in obj.__class__.__mro__: + if '_class_attribute_' in obj_cls_base.__dict__: + descr_bases[obj.__class__] = obj_cls_base + descriptors[obj_cls_base].append((obj, name)) + break + else: + # Make sure to put obj first for sorting purposes + obj_cls_base = descr_bases[obj.__class__] + descriptors[obj_cls_base].append((obj, name)) + + if not (isinstance(mro_cls, type(cls)) and + mro_cls._inherit_descriptors_): + # If _inherit_descriptors_ is undefined then we don't inherit + # any OrderedDescriptors from any of the base classes, and + # there's no reason to continue through the MRO + break + else: + inherit_descriptors = mro_cls._inherit_descriptors_ + + for descriptor_cls, instances in descriptors.items(): + instances.sort() + instances = OrderedDict((key, value) for value, key in instances) + setattr(cls, descriptor_cls._class_attribute_, instances) + + super(OrderedDescriptorContainer, cls).__init__(cls_name, bases, + members) diff --git a/astropy/utils/setup_package.py b/astropy/utils/setup_package.py index 5254835af4fe..7260ef9337ba 100644 --- a/astropy/utils/setup_package.py +++ b/astropy/utils/setup_package.py @@ -21,9 +21,10 @@ def get_package_data(): 'data/test_package/*.py', 'data/test_package/data/*.txt', 'data/*.dat', - 'data/*.dat.gz', - 'data/*.dat.bz2', 'data/*.txt', + 'data/*.gz', + 'data/*.bz2', + 'data/*.xz', 'data/.hidden_file.txt', 'data/*.cfg'], 'astropy.utils.iers': [ diff --git a/astropy/utils/tests/data/local.dat.xz b/astropy/utils/tests/data/local.dat.xz new file mode 100644 index 000000000000..481dbd2cfdbd Binary files /dev/null and b/astropy/utils/tests/data/local.dat.xz differ diff --git a/astropy/utils/tests/data/unicode.txt.bz2 b/astropy/utils/tests/data/unicode.txt.bz2 new file mode 100644 index 000000000000..a201846b7b9c Binary files /dev/null and b/astropy/utils/tests/data/unicode.txt.bz2 differ diff --git a/astropy/utils/tests/data/unicode.txt.gz b/astropy/utils/tests/data/unicode.txt.gz new file mode 100644 index 000000000000..8126d6013d68 Binary files /dev/null and b/astropy/utils/tests/data/unicode.txt.gz differ diff --git a/astropy/utils/tests/data/unicode.txt.xz b/astropy/utils/tests/data/unicode.txt.xz new file mode 100644 index 000000000000..262d1c8c3ade Binary files /dev/null and b/astropy/utils/tests/data/unicode.txt.xz differ diff --git a/astropy/utils/tests/test_data.py b/astropy/utils/tests/test_data.py index 6a5f2254e36c..7386b307fde8 100644 --- a/astropy/utils/tests/test_data.py +++ b/astropy/utils/tests/test_data.py @@ -29,6 +29,15 @@ else: HAS_BZ2 = True +try: + if sys.version_info >= (3,3,0): + import lzma + else: + from backports import lzma +except ImportError: + HAS_XZ = False +else: + HAS_XZ = True @remote_data def test_download_nocache(): @@ -98,33 +107,34 @@ def test_find_by_hash(): # Package data functions -@pytest.mark.parametrize(('filename'), ['local.dat', 'local.dat.gz', 'local.dat.bz2']) +@pytest.mark.parametrize(('filename'), ['local.dat', 'local.dat.gz', 'local.dat.bz2', 'local.dat.xz']) def test_local_data_obj(filename): from ..data import get_pkg_data_fileobj - try: + if (not HAS_BZ2 and 'bz2' in filename) or (not HAS_XZ and 'xz' in filename): + with pytest.raises(ValueError) as e: + with get_pkg_data_fileobj(os.path.join('data', filename), encoding='binary') as f: + f.readline() + # assert f.read().rstrip() == b'CONTENT' + assert ' format files are not supported' in str(e) + else: with get_pkg_data_fileobj(os.path.join('data', filename), encoding='binary') as f: f.readline() assert f.read().rstrip() == b'CONTENT' - except ValueError: - if not HAS_BZ2 and 'bz2' in filename: - pass - else: - raise @pytest.mark.parametrize(('filename'), ['invalid.dat.gz', 'invalid.dat.bz2']) def test_local_data_obj_invalid(filename): from ..data import get_pkg_data_fileobj - try: + if (not HAS_BZ2 and 'bz2' in filename) or (not HAS_XZ and 'xz' in filename): + with pytest.raises(ValueError) as e: + with get_pkg_data_fileobj(os.path.join('data', filename), encoding='binary') as f: + f.read() + assert ' format files are not supported' in str(e) + else: with get_pkg_data_fileobj(os.path.join('data', filename), encoding='binary') as f: assert f.read().rstrip().endswith(b'invalid') - except ValueError: - if not HAS_BZ2 and 'bz2' in filename: - pass - else: - raise def test_local_data_name(): @@ -282,15 +292,20 @@ def osraiser(dirnm, linkto): assert not os.path.isdir(lockdir), 'Cache dir lock was not released!' -def test_read_unicode(): +@pytest.mark.parametrize(('filename'), [ + 'unicode.txt', + 'unicode.txt.gz', + pytest.mark.xfail(not HAS_BZ2, reason='no bz2 support')('unicode.txt.bz2'), + pytest.mark.xfail(not HAS_XZ, reason='no lzma support')('unicode.txt.xz') ]) +def test_read_unicode(filename): from ..data import get_pkg_data_contents - contents = get_pkg_data_contents('data/unicode.txt', encoding='utf-8') + contents = get_pkg_data_contents(os.path.join('data', filename), encoding='utf-8') assert isinstance(contents, six.text_type) contents = contents.splitlines()[1] assert contents == "האסטרונומי פייתון" - contents = get_pkg_data_contents('data/unicode.txt', encoding='binary') + contents = get_pkg_data_contents(os.path.join('data', filename), encoding='binary') assert isinstance(contents, bytes) x = contents.splitlines()[1] assert x == b"\xff\xd7\x94\xd7\x90\xd7\xa1\xd7\x98\xd7\xa8\xd7\x95\xd7\xa0\xd7\x95\xd7\x9e\xd7\x99 \xd7\xa4\xd7\x99\xd7\x99\xd7\xaa\xd7\x95\xd7\x9f"[1:] diff --git a/docs/install.rst b/docs/install.rst index eaa638827e27..4f173ab666a6 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -337,6 +337,8 @@ packages: and most affiliated packages include this as a submodule in the source repository, so it does not need to be installed separately.) + - `WCSAxes `_ + .. note:: Sphinx also requires a reasonably modern LaTeX installation to render diff --git a/docs/io/unified.rst b/docs/io/unified.rst index bae0eb66f6b0..af95507aee06 100644 --- a/docs/io/unified.rst +++ b/docs/io/unified.rst @@ -49,6 +49,11 @@ Similarly, for writing, the format can be explicitly specified:: As for the :meth:`~astropy.table.Table.read` method, the format may be automatically identified in some cases. +The underlying file handler will also automatically detect various +compressed data formats and transparently uncompress them as far as +supported by the Python installation (see +:meth:`~astropy.utils.data.get_readable_fileobj`). + Any additional arguments specified will depend on the format. For examples of this see the section `Built-in table readers/writers`_. This section also provides the full list of choices for the ``format`` argument. diff --git a/docs/modeling/bounding-boxes.rst b/docs/modeling/bounding-boxes.rst index 40008015fcb0..7fd82f897d01 100644 --- a/docs/modeling/bounding-boxes.rst +++ b/docs/modeling/bounding-boxes.rst @@ -5,7 +5,8 @@ Efficient Model Rendering with Bounding Boxes .. versionadded:: 1.1 -All `astropy.modeling.Model` subclasses have a ``bounding_box`` attribute that +All `Model ` subclasses have a +`bounding_box ` attribute that can be used to set the limits over which the model is significant. This greatly improves the effciency of evaluation when the input range is much larger than the characteristic width of the model itself. For example, to create a sky model @@ -13,137 +14,138 @@ image from a large survey catalog, each source should only be evaluated over the pixels to which it contributes a significant amount of flux. This task can otherwise be computationally prohibitive on an average CPU. -The `astropy.modeling.render_model` function can be used to evaluate a model on -an input array, or coordinates, limiting the evaluation to the ``bounding_box`` -region if it is set. This function will also produce postage stamp images of the -model if no other input array is passed. To instead extract postage -stamps from the data array itself, see :ref:`cutout_images`. +The :func:`Model.render ` method can be used to +evaluate a model on an output array, or input coordinate arrays, limiting the +evaluation to the `bounding_box ` region if +it is set. This function will also produce postage stamp images of the model if +no other input array is passed. To instead extract postage stamps from the data +array itself, see :ref:`cutout_images`. Using the Bounding Box ----------------------- -For basic usage, see `astropy.modeling.Model.bounding_box`. -By default no bounding box is set (``bounding_box`` is `None`), except for -individual model subclasses that have a ``bounding_box_default`` function -defined. ``bounding_box_default`` returns the minimum rectangular region -symmetric about the position that fully contains the model if the model has a -finite extent. If a model does not have a finite extent, the choice for the -``bounding_box_default`` limits is noted in the docstring. For example, see -`astropy.modeling.functional_models.Gaussian2D.bounding_box_default`. - -The default function can also be set to any callable. This is particularly -useful for fitting ``custom_model`` or ``CompoundModel`` instances. +For basic usage, see `Model.bounding_box `. +By default no `bounding_box ` is set +(:func:`Model.bounding_box_default ` +returns `None`), except for model subclasses where :func:`bounding_box_default +` is explicity defined. The default +is then the minimum rectangular region symmetric about the position that fully +contains the model. If the model does not have a finite extent, the containment +criteria are noted in the documentation. For example, see +`Gaussian2D.bounding_box_default +`. + +`Model.bounding_box_default ` can +be set by the user to any callable. This is particularly useful for fitting +``custom_model`` or ``CompoundModel`` instances. >>> from astropy.modeling import custom_model >>> def ellipsoid(x, y, z, x0=0, y0=0, z0=0, a=2, b=3, c=4, amp=1): - ... rsq = ((x-x0)/a) ** 2 + ((y-y0)/b) ** 2 + ((z-z0)/c) ** 2 + ... rsq = ((x - x0) / a) ** 2 + ((y - y0) / b) ** 2 + ((z - z0) / c) ** 2 ... val = (rsq < 1) * amp ... return val ... - >>> def ellipsoid_bb(self): - ... return ((self.z0 - self.c, self.z0 + self.c), - ... (self.y0 - self.b, self.y0 + self.b), - ... (self.x0 - self.a, self.x0 + self.a)) + >>> class Ellipsoid3D(custom_model(ellipsoid)): + ... # A 3D ellipsoid model + ... def bounding_box_default(self): + ... return ((self.z0 - self.c, self.z0 + self.c), + ... (self.y0 - self.b, self.y0 + self.b), + ... (self.x0 - self.a, self.x0 + self.a)) ... - >>> Ellipsoid3D = custom_model(ellipsoid) - >>> Ellipsoid3D.bounding_box_default = ellipsoid_bb >>> model = Ellipsoid3D() - >>> model.bounding_box = 'auto' >>> model.bounding_box ((-4.0, 4.0), (-3.0, 3.0), (-2.0, 2.0)) -Efficient evaluation with ``render_model`` ------------------------------------------- +Efficient evaluation with :func:`Model.render() ` +-------------------------------------------------------------------------------- When a model is evaluated over a range much larger than the model itself, it may -be prudent to use `astropy.modeling.render_model` if efficiency is a concern. -The ``render_model`` function can be used to evaluate a model on an array of the -same dimensions. If no array is given, ``render_model`` will return a "postage -stamp" array corresponding to the bounding box region. However, if -``bounding_box`` is `None` an image or coordinates must be passed. +be prudent to use the :func:`Model.render ` +method if efficiency is a concern. The :func:`render ` +method can be used to evaluate the model on an array of the same dimensions. +``model.render()`` can be called with no arguments to return a "postage +stamp" of the bounding box region. In this example, we generate a 300x400 pixel image of 100 2D -Gaussian sources both with and without using bounding boxes. Using bounding -boxes, the evaluation speed increases by approximately a factor of 10 with -negligible loss of information. +Gaussian sources. For comparison, the models are evaluated +both with and without using bounding boxes. By using bounding boxes, the evaluation +speed increases by approximately a factor of 10 with negligible loss of information. .. plot:: :include-source: import numpy as np from time import time - from astropy.modeling import models, render_model - - import matplotlib as mpl + from astropy.modeling import models import matplotlib.pyplot as plt - from astropy.visualization import astropy_mpl_style - astropy_mpl_style['axes.grid'] = False - astropy_mpl_style['axes.labelcolor'] = 'k' - mpl.rcParams.update(astropy_mpl_style) - - np.random.seed(0) + from matplotlib.patches import Rectangle imshape = (300, 400) - nsrc = 100 + y, x = np.indices(imshape) + # Generate random source model list + np.random.seed(0) + nsrc = 100 model_params = [ - dict(amplitude = np.random.uniform(0, 1), - x_mean = np.random.uniform(0, imshape[1]), - y_mean = np.random.uniform(0, imshape[0]), - x_stddev = np.abs(np.random.uniform(3, 6)), - y_stddev = np.abs(np.random.uniform(3, 6)), - theta = np.random.uniform(0, 2 * np.pi)) - for i in range(nsrc)] + dict(amplitude=np.random.uniform(.5, 1), + x_mean=np.random.uniform(0, imshape[1] - 1), + y_mean=np.random.uniform(0, imshape[0] - 1), + x_stddev=np.random.uniform(2, 6), + y_stddev=np.random.uniform(2, 6), + theta=np.random.uniform(0, 2 * np.pi)) + for _ in range(nsrc)] model_list = [models.Gaussian2D(**kwargs) for kwargs in model_params] - #Evaluate all models over their bounded regions and over the full image - #for comparison. - - def make_image(model_list, shape=imshape, mode='bbox'): - image = np.zeros(imshape) - t1 = time() - for i,model in enumerate(model_list): - if mode == 'full': model.bounding_box = None - elif mode == 'auto': model.bounding_box = 'auto' - image = render_model(model, image) - t2 = time() - return image, (t2 - t1) + # Render models to image using bounding boxes + bb_image = np.zeros(imshape) + t_bb = time() + for model in model_list: + model.render(bb_image) + t_bb = time() - t_bb - bb_image, t_bb = make_image(model_list, mode='auto') - full_image, t_full = make_image(model_list, mode='full') + # Render models to image using full evaluation + full_image = np.zeros(imshape) + t_full = time() + for model in model_list: + model.bounding_box = None + model.render(full_image) + t_full = time() - t_full flux = full_image.sum() diff = (full_image - bb_image) max_err = diff.max() + # Plots plt.figure(figsize=(16, 7)) - plt.subplots_adjust(left=.05,right=.97,bottom=.03,top=.97,wspace=0.1)#07) + plt.subplots_adjust(left=.05, right=.97, bottom=.03, top=.97, wspace=0.15) + # Full model image plt.subplot(121) plt.imshow(full_image, origin='lower') - plt.axis([0,imshape[1],0,imshape[0]]) - plt.title('Full Models\nTiming: %.2f seconds' % (t_full), fontsize=16) - plt.xlabel('x', fontsize=14) - plt.ylabel('y', fontsize=14) + plt.title('Full Models\nTiming: {:.2f} seconds'.format(t_full), fontsize=16) + plt.xlabel('x') + plt.ylabel('y') - plt.subplot(122) + # Bounded model image with boxes overplotted + ax = plt.subplot(122) plt.imshow(bb_image, origin='lower') for model in model_list: - y1,y2,x1,x2 = np.reshape(model.bounding_box_default(),(4,)) - plt.plot([x1,x2,x2,x1,x1], [y1,y1,y2,y2,y1], 'w-',alpha=.2) - - plt.axis([0,imshape[1],0,imshape[0]]) - plt.title('Bounded Models\nTiming: %.2f seconds' % (t_bb), fontsize=16) - plt.xlabel('x', fontsize=14) - plt.ylabel('y', fontsize=14) - - plt.figure(figsize=(16,8)) + dy, dx = np.diff(model.bounding_box_default()).flatten() + pos = (model.x_mean.value - dx / 2, model.y_mean.value - dy / 2) + r = Rectangle(pos, dx, dy, edgecolor='w', facecolor='none', alpha=.25) + ax.add_patch(r) + plt.title('Bounded Models\nTiming: {:.2f} seconds'.format(t_bb), fontsize=16) + plt.xlabel('x') + plt.ylabel('y') + + # Difference image + plt.figure(figsize=(16, 8)) + plt.subplot(111) plt.imshow(diff, vmin=-max_err, vmax=max_err) plt.colorbar(format='%.1e') - plt.title('Difference Image\nTotal Flux Err = %.0e' - %((flux - np.sum(bb_image)) / flux), fontsize=16) - plt.xlabel('x', fontsize=14) - plt.ylabel('y', fontsize=14) + plt.title('Difference Image\nTotal Flux Err = {:.0e}'.format( + ((flux - np.sum(bb_image)) / flux))) + plt.xlabel('x') + plt.ylabel('y') plt.show() - diff --git a/docs/nddata/utils.rst b/docs/nddata/utils.rst index b9c996aadd4c..5419737b165a 100644 --- a/docs/nddata/utils.rst +++ b/docs/nddata/utils.rst @@ -53,8 +53,9 @@ We create a cutout array centered at position ``(x, y) = (49.7, 100.1)`` with a shape of ``(ny, nx) = (40, 50)``:: >>> from astropy.nddata import Cutout2D + >>> from astropy import units as u >>> position = (49.7, 100.1) - >>> shape = (40, 50) + >>> shape = (40*u.pixel, 50*u.pixel) >>> cutout = Cutout2D(data, position, shape) The cutout array is stored in the ``data`` attribute of the @@ -134,6 +135,29 @@ including:: >>> print(cutout.slices_cutout) (slice(0, 40, None), slice(0, 50, None)) +Cutouts don't have to be specified by their shape if they are square. +Let's create another cutout array centered at position ``(x, y) = (49.7, +100.1)``, but this time with a square cutout that is 55 pixels to a side:: + + >>> side_length = 55*u.pixel + >>> cutout2 = Cutout2D(data, position, side_length=side_length) + +.. doctest-skip:: + + >>> plt.imshow(cutout2.data, origin='lower') + +.. plot:: + + import numpy as np + import matplotlib.pyplot as plt + from astropy.modeling.models import Gaussian2D + from astropy.nddata import Cutout2D + y, x = np.mgrid[0:500, 0:500] + data = Gaussian2D(1, 50, 100, 10, 5, theta=0.5)(x, y) + position = (49.7, 100.1) + cutout = Cutout2D(data, position, side_length=55) + plt.imshow(cutout.data, origin='lower') + There are also two `~astropy.nddata.utils.Cutout2D` methods to convert pixel positions between the original and cutout arrays:: diff --git a/docs/rtd-pip-requirements b/docs/rtd-pip-requirements index 1d066a713e7e..e405b4447faa 100644 --- a/docs/rtd-pip-requirements +++ b/docs/rtd-pip-requirements @@ -2,3 +2,4 @@ numpy>=1.6.0 matplotlib Cython +wcsaxes diff --git a/docs/wcs/index.rst b/docs/wcs/index.rst index f1659d258489..ff50c6758215 100644 --- a/docs/wcs/index.rst +++ b/docs/wcs/index.rst @@ -235,6 +235,26 @@ ability to use the :class:`~astropy.wcs.WCS` to define projections in Matplotlib. More information on installing and using WCSAxes can be found `here `__. +.. plot:: + :include-source: + + from matplotlib import pyplot as plt + from astropy.io import fits + from astropy.wcs import WCS + from astropy.utils.data import download_file + + fits_file = 'http://data.astropy.org/tutorials/FITS-images/HorseHead.fits' + image_file = download_file(fits_file, cache=True ) + hdu = fits.open(image_file)[0] + wcs = WCS(hdu.header) + + fig = plt.figure() + fig.add_subplot(111, projection=wcs) + plt.imshow(hdu.data, origin='lower', cmap='cubehelix') + plt.xlabel('RA') + plt.ylabel('Dec') + plt.show() + Other information =================