diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c855bf38..9b7e6009d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 diff --git a/docs/requirements.txt b/docs/requirements.txt index 9e039be4a..4c67cdac7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ multipledispatch numpy>=1.7 opt_einsum>=2.3.2 -unification pytest>=4.1 makefun +typing_extensions diff --git a/docs/source/index.rst b/docs/source/index.rst index 230a68b5d..6b45a5131 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,7 @@ Funsor is a tensor-like library for functions and distributions affine factory testing + typing .. toctree:: :glob: diff --git a/docs/source/typing.rst b/docs/source/typing.rst new file mode 100644 index 000000000..77ba6fabe --- /dev/null +++ b/docs/source/typing.rst @@ -0,0 +1,8 @@ +Typing Utiltites +---------------------- +.. automodule:: funsor.typing + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + diff --git a/funsor/cnf.py b/funsor/cnf.py index 2200a6e74..b81be57e4 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -8,7 +8,6 @@ from typing import Tuple, Union import opt_einsum -from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -32,6 +31,7 @@ Variable, to_funsor, ) +from funsor.typing import Variadic from funsor.util import broadcast_shape, get_backend, quote diff --git a/funsor/joint.py b/funsor/joint.py index cdc2a092c..16dea97ca 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -7,7 +7,6 @@ from typing import Tuple, Union from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture @@ -18,6 +17,7 @@ from funsor.ops import AssociativeOp from funsor.tensor import Tensor, align_tensor from funsor.terms import Funsor, Independent, Number, Reduce, Unary +from funsor.typing import Variadic @dispatch(str, str, Variadic[(Gaussian, GaussianMixture)]) diff --git a/funsor/optimizer.py b/funsor/optimizer.py index 3a56e7925..91cc06763 100644 --- a/funsor/optimizer.py +++ b/funsor/optimizer.py @@ -3,7 +3,6 @@ import collections -from multipledispatch.variadic import Variadic from opt_einsum.paths import greedy import funsor.interpreter as interpreter @@ -18,6 +17,7 @@ from funsor.interpreter import get_interpretation from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp from funsor.terms import Funsor +from funsor.typing import Variadic unfold_base = DispatchedInterpretation() unfold = PrioritizedInterpretation(unfold_base, normalize_base, lazy) diff --git a/funsor/registry.py b/funsor/registry.py index 07f9f5542..389c5dadd 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -3,8 +3,9 @@ from collections import defaultdict -from multipledispatch import Dispatcher -from multipledispatch.conflict import supercedes +from multipledispatch.dispatcher import Dispatcher, expand_tuples + +from funsor.typing import Variadic, deep_type, get_origin, get_type_hints, typing_wrap class PartialDispatcher(Dispatcher): @@ -12,11 +13,40 @@ class PartialDispatcher(Dispatcher): Wrapper to avoid appearance in stack traces. """ + def __init__(self, default=None): + self.default = default if default is None else PartialDefault(default) + super().__init__("PartialDispatcher") + if default is not None: + self.add(([object],), self.default) + + def add(self, signature, func): + + # Handle annotations + if not signature: + annotations = get_type_hints(func) + annotations.pop("return", None) + if annotations: + signature = tuple(annotations.values()) + + # Handle some union types by expanding at registration time + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + # Handle variadic types + signature = ( + Variadic[tuple(tp)] if isinstance(tp, list) else tp for tp in signature + ) + + signature = tuple(map(typing_wrap, signature)) + super().add(signature, func) + def partial_call(self, *args): """ Likde :meth:`__call__` but avoids calling ``func()``. """ - types = tuple(map(type, args)) + types = tuple(map(typing_wrap, map(deep_type, args))) try: func = self._cache[types] except KeyError: @@ -29,6 +59,9 @@ def partial_call(self, *args): self._cache[types] = func return func + def __call__(self, *args): + return self.partial_call(*args)(*args) + class PartialDefault: def __init__(self, default): @@ -44,20 +77,12 @@ def partial_call(self, *args): class KeyedRegistry(object): def __init__(self, default=None): - self.default = default if default is None else PartialDefault(default) # TODO make registry a WeakKeyDictionary - self.registry = defaultdict(lambda: PartialDispatcher("f")) + self.default = default if default is None else PartialDefault(default) + self.registry = defaultdict(lambda: PartialDispatcher(default=default)) def register(self, key, *types): - key = getattr(key, "__origin__", key) - register = self.registry[key].register - if self.default: - objects = (object,) * len(types) - try: - if objects != types and supercedes(types, objects): - register(*objects)(self.default) - except TypeError: - pass # mysterious source of ambiguity in Python 3.5 breaks this + register = self.registry[get_origin(key)].register # This decorator supports stacking multiple decorators, which is not # supported by multipledipatch (which returns a Dispatch object rather @@ -69,10 +94,10 @@ def decorator(fn): return decorator def __contains__(self, key): - return key in self.registry + return get_origin(key) in self.registry def __getitem__(self, key): - key = getattr(key, "__origin__", key) + key = get_origin(key) if self.default is None: return self.registry[key] return self.registry.get(key, self.default) diff --git a/funsor/tensor.py b/funsor/tensor.py index 3a7fcabe2..9ac05ac4d 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -12,7 +12,6 @@ import numpy as np import opt_einsum from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor @@ -35,6 +34,7 @@ to_data, to_funsor, ) +from .typing import Variadic from .util import get_backend, get_tracing_state, getargspec, is_nn_module, quote diff --git a/funsor/terms.py b/funsor/terms.py index 3b13ced73..fa9c78288 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -12,7 +12,6 @@ from weakref import WeakValueDictionary from multipledispatch import dispatch -from multipledispatch.variadic import Variadic, isvariadic import funsor.interpreter as interpreter import funsor.ops as ops @@ -29,6 +28,7 @@ from funsor.interpreter import PatternMissingError, interpret from funsor.ops import AssociativeOp, GetitemOp, Op from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS +from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_origin from funsor.util import getargspec, lazy_property, pretty, quote from . import instrument, interpreter, ops @@ -105,15 +105,8 @@ def reflect(cls, *args, **kwargs): if cache_key in cls._cons_cache: return cls._cons_cache[cache_key] - arg_types = tuple( - typing.Tuple[tuple(map(type, arg))] - if (type(arg) is tuple and all(isinstance(a, Funsor) for a in arg)) - else typing.Tuple - if (type(arg) is tuple and not arg) - else type(arg) - for arg in args - ) - cls_specific = (cls.__origin__ if cls.__args__ else cls)[arg_types] + arg_types = tuple(map(deep_type, args)) + cls_specific = get_origin(cls)[arg_types] result = super(FunsorMeta, cls_specific).__call__(*args) result._ast_values = args @@ -121,7 +114,7 @@ def reflect(cls, *args, **kwargs): size, depth, width = _get_ast_stats(result) instrument.COUNTERS["ast_size"][size] += 1 instrument.COUNTERS["ast_depth"][depth] += 1 - classname = getattr(cls, "__origin__", cls).__name__ + classname = get_origin(cls).__name__ instrument.COUNTERS["funsor"][classname] += 1 instrument.COUNTERS[classname][width] += 1 @@ -132,7 +125,7 @@ def reflect(cls, *args, **kwargs): return result -class FunsorMeta(type): +class FunsorMeta(GenericTypeMeta): """ Metaclass for Funsors to perform four independent tasks: @@ -157,15 +150,17 @@ class FunsorMeta(type): def __init__(cls, name, bases, dct): super(FunsorMeta, cls).__init__(name, bases, dct) - if not hasattr(cls, "__args__"): - cls.__args__ = () - if cls.__args__: - (base,) = bases - cls.__origin__ = base - else: + if not cls.__args__: cls._ast_fields = getargspec(cls.__init__)[0][1:] cls._cons_cache = WeakValueDictionary() - cls._type_cache = WeakValueDictionary() + + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + assert len(arg_types) == len( + cls._ast_fields + ), "Must provide exactly one type per subexpression" + return super().__getitem__(arg_types) def __call__(cls, *args, **kwargs): if cls.__args__: @@ -181,104 +176,9 @@ def __call__(cls, *args, **kwargs): return interpret(cls, *args) - def __getitem__(cls, arg_types): - if not isinstance(arg_types, tuple): - arg_types = (arg_types,) - assert not any( - isvariadic(arg_type) for arg_type in arg_types - ), "nested variadic types not supported" - # switch tuple to typing.Tuple - arg_types = tuple( - typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types - ) - if arg_types not in cls._type_cache: - assert not cls.__args__, "cannot subscript a subscripted type {}".format( - cls - ) - assert len(arg_types) == len( - cls._ast_fields - ), "must provide types for all params" - new_dct = cls.__dict__.copy() - new_dct.update({"__args__": arg_types}) - # type(cls) to handle FunsorMeta subclasses - cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) - return cls._type_cache[arg_types] - - def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) - if cls is subcls: - return True - if not isinstance(subcls, FunsorMeta): - return super(FunsorMeta, getattr(cls, "__origin__", cls)).__subclasscheck__( - subcls - ) - - cls_origin = getattr(cls, "__origin__", cls) - subcls_origin = getattr(subcls, "__origin__", subcls) - if not super(FunsorMeta, cls_origin).__subclasscheck__(subcls_origin): - return False - - if cls.__args__: - if not subcls.__args__: - return False - if len(cls.__args__) != len(subcls.__args__): - return False - for subcls_param, param in zip(subcls.__args__, cls.__args__): - if not _issubclass_tuple(subcls_param, param): - return False - return True - @lazy_property def classname(cls): - return cls.__name__ + "[{}]".format( - ", ".join( - str(getattr(t, "classname", t)) # Tuple doesn't have __name__ - for t in cls.__args__ - ) - ) - - -def _issubclass_tuple(subcls, cls): - """ - utility for pattern matching with tuple subexpressions - """ - # so much boilerplate... - cls_is_union = ( - hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union - ) - if isinstance(cls, tuple) or cls_is_union: - return any( - _issubclass_tuple(subcls, option) - for option in (getattr(cls, "__args__", []) if cls_is_union else cls) - ) - - subcls_is_union = ( - hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union - ) - if isinstance(subcls, tuple) or subcls_is_union: - return any( - _issubclass_tuple(option, cls) - for option in ( - getattr(subcls, "__args__", []) if subcls_is_union else subcls - ) - ) - - subcls_is_tuple = hasattr(subcls, "__origin__") and ( - subcls.__origin__ or subcls - ) in (tuple, typing.Tuple) - cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in ( - tuple, - typing.Tuple, - ) - if subcls_is_tuple != cls_is_tuple: - return False - if not cls_is_tuple: - return issubclass(subcls, cls) - if not cls.__args__: - return True - if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__): - return False - - return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__)) + return repr(cls) def _convert_reduced_vars(reduced_vars, inputs): diff --git a/funsor/typing.py b/funsor/typing.py new file mode 100644 index 000000000..38e9adf2c --- /dev/null +++ b/funsor/typing.py @@ -0,0 +1,367 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import functools +import sys +import typing +import weakref + +import typing_extensions +from multipledispatch.variadic import Variadic as _OrigVariadic +from multipledispatch.variadic import isvariadic + +################################# +# Runtime type-checking helpers +################################# + + +@functools.singledispatch +def deep_type(obj): + """ + An enhanced version of :func:`type` that reconstructs structured :mod:`typing`` types + for a limited set of immutable data structures, notably ``tuple`` and ``frozenset``. + Mostly intended for internal use in Funsor interpretation pattern-matching. + + Example:: + + assert deep_type((1, ("a",))) is typing.Tuple[int, typing.Tuple[str]] + assert deep_type(frozenset(["a"])) is typing.FrozenSet[str] + """ + # compare to pytypes.deep_type(obj) + return type(obj) + + +@deep_type.register(tuple) +def _deep_type_tuple(obj): + return typing.Tuple[tuple(map(deep_type, obj))] if obj else typing.Tuple + + +@deep_type.register(frozenset) +def _deep_type_frozenset(obj): + if not obj: + return typing.FrozenSet + tp = deep_type(next(iter(obj))) + for x in obj: + if not deep_isinstance(x, tp): + tp = get_origin(tp) + if not deep_isinstance(x, tp): + raise NotImplementedError( + f"TODO handle inhomogeneous frozensets: {str(obj)}" + ) + return typing.FrozenSet[tp] + + +_subclasscheck_registry = {} + + +def register_subclasscheck(cls): + """ + Decorator for registering a custom ``__subclasscheck__`` method for ``cls`` + which is only ever invoked in :func:`deep_issubclass`. + + This is primarily intended for working with the :mod:`typing` library at runtime. + Prefer overriding ``__subclasscheck__`` in the usual way with a metaclass + where possible. + """ + + def _fn(fn): + _subclasscheck_registry[cls] = fn + return fn + + return _fn + + +@register_subclasscheck(typing.Any) +def _subclasscheck_any(cls, subcls): + return True + + +@register_subclasscheck(typing.Union) +def _subclasscheck_union(cls, subcls): + """A basic ``__subclasscheck__`` method for :class:`~typing.Union`.""" + return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) + + +@register_subclasscheck(frozenset) +@register_subclasscheck(typing.FrozenSet) +def _subclasscheck_frozenset(cls, subcls): + """A basic ``__subclasscheck__`` method for :class:`~typing.FrozenSet`.""" + + if not issubclass(get_origin(subcls), frozenset): + return False + + cls_args, subcls_args = get_args(cls), get_args(subcls) + + if not cls_args: + return True + + if not subcls_args: + return cls_args[0] is typing.Any + + return len(subcls_args) == len(cls_args) == 1 and all( + deep_issubclass(a, b) for a, b in zip(subcls_args, cls_args) + ) + + +@register_subclasscheck(tuple) +@register_subclasscheck(typing.Tuple) +def _subclasscheck_tuple(cls, subcls): + """A basic ``__subclasscheck__`` method for :class:`~typing.Tuple`.""" + + if not issubclass(get_origin(subcls), get_origin(cls)): + return False + + cls_args, subcls_args = get_args(cls), get_args(subcls) + + if not cls_args: # cls is base Tuple + return True + + if not subcls_args: + return cls_args[0] is typing.Any + + if cls_args[-1] is Ellipsis: # cls variadic + if subcls_args[-1] is Ellipsis: # both variadic + return deep_issubclass(subcls_args[0], cls_args[0]) + return all(deep_issubclass(a, cls_args[0]) for a in subcls_args) + + if subcls_args[-1] is Ellipsis: # only subcls variadic + # issubclass(Tuple[A, ...], Tuple[X, Y]) == False + return False + + # neither variadic + return len(cls_args) == len(subcls_args) and all( + deep_issubclass(a, b) for a, b in zip(subcls_args, cls_args) + ) + + +@functools.lru_cache(maxsize=None) +def deep_issubclass(subcls, cls): + """ + Enhanced version of :func:`issubclass` that can handle structured types, + including Funsor terms, :class:`~typing.Tuple`, and :class:`~typing.FrozenSet`. + + Does not support more advanced :mod:`typing` features such as + :class:`~typing.TypeVar`, arbitrary :class:`~typing.Generic` subtypes, + forward references, or mutable collection types like :class:`~typing.List`. + Will attempt to fall back to :func:`issubclass` when it encounters a type in + ``subcls`` or ``cls`` that it does not understand. + + Usage:: + + class A: pass + class B(A): pass + + assert deep_issubclass(typing.Tuple[int, B], typing.Tuple[int, A]) + assert not deep_issubclass(typing.Tuple[int, A], typing.Tuple[int, B]) + + assert deep_issubclass(typing.Tuple[A, A], typing.Tuple[A, ...]) + assert not deep_issubclass(typing.Tuple[B], typing.Tuple[A, ...]) + + :param subcls: A class that may be a subclass of ``cls``. + :param cls: A class that may be a parent class of ``subcls``. + """ + # compare to pytypes.is_subtype(subcls, cls) + + # handle unpacking + if isinstance(subcls, _RuntimeSubclassCheckMeta): + try: + return deep_issubclass(subcls.__args__[0], cls) + except TypeError as e: + if e.args[0] == "issubclass() arg 1 must be a class": + return deep_issubclass(get_origin(subcls.__args__[0]), cls) + raise + + if get_origin(subcls) is typing.Union: + return all(deep_issubclass(arg, cls) for arg in get_args(subcls)) + + if subcls is typing.Any: + return cls is typing.Any + + try: + return _subclasscheck_registry[get_origin(cls)](cls, subcls) + except KeyError: + return issubclass(subcls, cls) + + +def deep_isinstance(obj, cls): + """ + Enhanced version of :func:`isinstance` that can handle basic structured :mod:`typing` types, + including Funsor terms and other :class:`~funsor.typing.GenericTypeMeta` instances, + :class:`~typing.Union`, :class:`~typing.Tuple`, and :class:`~typing.FrozenSet`. + + Does not support :class:`~typing.TypeVar`, arbitrary :class:`~typing.Generic`, + forward references, or mutable generic collection types like :class:`~typing.List`. + Will attempt to fall back to :func:`isinstance` when it encounters + an unsupported type in ``obj`` or ``cls``. + + Usage:: + + x = (1, ("a", "b")) + assert deep_isinstance(x, typing.Tuple[int, tuple]) + assert deep_isinstance(x, typing.Tuple[typing.Any, typing.Tuple[str, ...]]) + + :param obj: An object that may be an instance of ``cls``. + :param cls: A class that may be a parent class of ``obj``. + """ + + # compare to pytypes.is_of_type(obj, cls) + try: + return deep_issubclass(deep_type(obj), cls) + except TypeError: + return isinstance(obj, cls) + + +def _type_to_typing(tp): + if tp is object: + tp = typing.Any + return tp + + +############################################## +# Funsor-compatible typing introspection API +############################################## + + +def get_args(tp): + if isinstance(tp, GenericTypeMeta) or sys.version_info[:2] < (3, 7): + result = getattr(tp, "__args__", None) + else: + result = typing_extensions.get_args(tp) + return () if result is None else result + + +def get_origin(tp): + if isinstance(tp, GenericTypeMeta) or sys.version_info[:2] < (3, 7): + result = getattr(tp, "__origin__", None) + else: + result = typing_extensions.get_origin(tp) + return tp if result is None else result + + +if sys.version_info[:2] >= (3, 7): # reuse upstream documentation if possible + get_args = functools.wraps(typing_extensions.get_args)(get_args) + get_origin = functools.wraps(typing_extensions.get_origin)(get_origin) + + +@functools.wraps(typing.get_type_hints) +def get_type_hints(obj, globalns=None, localns=None, **kwargs): + return typing.get_type_hints(obj, globalns=globalns, localns=localns, **kwargs) + + +###################################################################### +# Metaclass for generating parametric types with Tuple-like variance +###################################################################### + + +class GenericTypeMeta(type): + """ + Metaclass to support subtyping with parameters for pattern matching, e.g. ``Number[int, int]``. + """ + + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + if not hasattr(cls, "__args__"): + cls.__args__ = () + if cls.__args__: + (base,) = bases + cls.__origin__ = base + else: + cls._type_cache = weakref.WeakValueDictionary() + + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + arg_types = tuple(map(_type_to_typing, arg_types)) + try: + return cls._type_cache[arg_types] + except KeyError: + assert not get_args(cls), "cannot subscript a subscripted type {}".format( + cls + ) + assert not any( + isvariadic(arg_type) for arg_type in arg_types + ), "nested variadic types not supported" + new_dct = cls.__dict__.copy() + new_dct.update({"__args__": arg_types}) + # type(cls) to handle GenericTypeMeta subclasses + result = type(cls)(cls.__name__, (cls,), new_dct) + cls._type_cache[arg_types] = result + return result + + def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) + if cls is subcls: + return True + + cls_origin = get_origin(cls) + if not isinstance(subcls, GenericTypeMeta): + return super(GenericTypeMeta, cls_origin).__subclasscheck__(subcls) + + if not super(GenericTypeMeta, cls_origin).__subclasscheck__(get_origin(subcls)): + return False + + cls_args, subcls_args = get_args(cls), get_args(subcls) + if len(cls_args) != len(subcls_args): + return len(cls_args) == 0 + + return all( + deep_issubclass(_type_to_typing(ps), _type_to_typing(pc)) + for ps, pc in zip(subcls_args, cls_args) + ) + + def __repr__(cls): + return get_origin(cls).__name__ + ( + "" + if not get_args(cls) + else "[{}]".format(", ".join(repr(t) for t in get_args(cls))) + ) + + +############################################################## +# Tools and overrides for typing-compatible multipledispatch +############################################################## + + +class _RuntimeSubclassCheckMeta(GenericTypeMeta): + def __call__(cls, tp): + tp = _type_to_typing(tp) + return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else cls[tp] + + def __subclasscheck__(cls, subcls): + return deep_issubclass(subcls, cls.__args__[0]) + + +class typing_wrap(metaclass=_RuntimeSubclassCheckMeta): + """ + Utility callable for overriding the runtime behavior of :mod:`typing` objects. + """ + + pass + + +class _DeepVariadicSignatureType(type): + def __getitem__(cls, key): + if not isinstance(key, tuple): + key = (key,) + return _OrigVariadic[tuple(map(typing_wrap, key))] + + +class Variadic(metaclass=_DeepVariadicSignatureType): + """ + A typing-compatible drop-in replacement for :class:`~multipledispatch.variadic.Variadic`. + """ + + pass + + +__all__ = [ + "GenericTypeMeta", + "Variadic", + "deep_isinstance", + "deep_issubclass", + "deep_type", + "get_args", + "get_origin", + "get_type_hints", + "register_subclasscheck", + "typing_wrap", +] diff --git a/setup.py b/setup.py index d0b6de0e0..7ef9cd980 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,13 @@ project_urls={"Documentation": "https://funsor.pyro.ai"}, author="Uber AI Labs", python_requires=">=3.6", - install_requires=["makefun", "multipledispatch", "numpy>=1.7", "opt_einsum>=2.3.2"], + install_requires=[ + "makefun", + "multipledispatch", + "numpy>=1.7", + "opt_einsum>=2.3.2", + "typing_extensions", + ], extras_require={ "torch": ["pyro-ppl>=1.5.2", "torch>=1.7.0"], "jax": ["numpyro>=0.2.4", "jax>=0.1.57", "jaxlib>=0.1.37"], diff --git a/test/test_cnf.py b/test/test_cnf.py index b7fde5907..8a85486bd 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -13,7 +13,7 @@ BACKEND_TO_LOGSUMEXP_BACKEND, Contraction, ) -from funsor.domains import Bint, Reals # noqa F403 +from funsor.domains import Array, Bint, Reals # noqa: F401 from funsor.einsum import einsum, naive_plated_einsum from funsor.interpretations import eager, normalize, reflect from funsor.interpreter import reinterpret diff --git a/test/test_terms.py b/test/test_terms.py index daa5f49a5..e495e3e32 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -29,7 +29,6 @@ from funsor.terms import ( Binary, Cat, - Funsor, Independent, Lambda, Number, @@ -45,6 +44,7 @@ from funsor.testing import assert_close, check_funsor, random_tensor assert Binary # flake8 +assert Reduce # flake8 assert Subs # flake8 assert Contraction # flake8 assert Reals # flake8 @@ -576,103 +576,6 @@ def test_align_simple(): assert f(x=1, y=2, z=3) == g(x=1, y=2, z=3) -@pytest.mark.parametrize( - "subcls_expr,cls_expr", - [ - ("Reduce", "Reduce"), - ("Reduce[ops.AssociativeOp, Funsor, frozenset]", "Funsor"), - ("Reduce[ops.AssociativeOp, Funsor, frozenset]", "Reduce"), - ( - "Reduce[ops.AssociativeOp, Funsor, frozenset]", - "Reduce[ops.Op, Funsor, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.Op, Funsor, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.AssociativeOp, Reduce, frozenset]", - ), - ("Stack[str, typing.Tuple[Number, Number, Number]]", "Stack"), - ("Stack[str, typing.Tuple[Number, Number, Number]]", "Stack[str, tuple]"), - # Unions - ( - "Reduce[ops.AssociativeOp, (Number, Stack[str, (tuple, typing.Tuple[Number, Number])]), frozenset]", - "Funsor", - ), - ( - "Reduce[ops.AssociativeOp, (Number, Stack), frozenset]", - "Reduce[ops.Op, Funsor, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, (Stack, Reduce[ops.AssociativeOp, (Number, Stack), frozenset]), frozenset]", - "Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]", - ), - ], -) -def test_parametric_subclass(subcls_expr, cls_expr): - subcls = eval(subcls_expr) - cls = eval(cls_expr) - print(subcls.classname) - print(cls.classname) - assert issubclass(cls, (Funsor, Reduce)) and not issubclass( - subcls, typing.Tuple - ) # appease flake8 - assert issubclass(subcls, cls) - - -@pytest.mark.parametrize( - "subcls_expr,cls_expr", - [ - ("Funsor", "Reduce[ops.AssociativeOp, Funsor, frozenset]"), - ("Reduce", "Reduce[ops.AssociativeOp, Funsor, frozenset]"), - ( - "Reduce[ops.Op, Funsor, frozenset]", - "Reduce[ops.AssociativeOp, Funsor, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.Op, Variable, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.AssociativeOp, Reduce[ops.AddOp, Funsor, frozenset], frozenset]", - ), - ("Stack", "Stack[str, typing.Tuple[Number, Number, Number]]"), - ("Stack[str, tuple]", "Stack[str, typing.Tuple[Number, Number, Number]]"), - ( - "Stack[str, typing.Tuple[Number, Number]]", - "Stack[str, typing.Tuple[Number, Reduce]]", - ), - ( - "Stack[str, typing.Tuple[Number, Reduce]]", - "Stack[str, typing.Tuple[Number, Number]]", - ), - # Unions - ("Funsor", "Reduce[ops.AssociativeOp, (Number, Funsor), frozenset]"), - ( - "Reduce[ops.Op, Funsor, frozenset]", - "Reduce[ops.AssociativeOp, (Number, Stack), frozenset]", - ), - ( - "Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]", - "Reduce[ops.AssociativeOp, (Stack[str, tuple], " - "Reduce[ops.AssociativeOp, (Cat, Stack), frozenset]), frozenset]", - ), - ], -) -def test_not_parametric_subclass(subcls_expr, cls_expr): - subcls = eval(subcls_expr) - cls = eval(cls_expr) - print(subcls.classname) - print(cls.classname) - assert issubclass(cls, (Funsor, Reduce)) and not issubclass( - subcls, typing.Tuple - ) # appease flake8 - assert not issubclass(subcls, cls) - - @pytest.mark.parametrize( "start,stop", [(1, 3), (0, 1), (3, 7), (4, 8), (0, 2), (0, 10), (1, 2), (1, 10), (2, 10)], diff --git a/test/test_typing.py b/test/test_typing.py new file mode 100644 index 000000000..4a2393503 --- /dev/null +++ b/test/test_typing.py @@ -0,0 +1,357 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, FrozenSet, Optional, Tuple, Union + +import pytest +from multipledispatch import dispatch + +from funsor.ops import AssociativeOp, Op +from funsor.registry import PartialDispatcher +from funsor.terms import Cat, Funsor, Number, Reduce, Stack, Variable +from funsor.typing import ( + GenericTypeMeta, + Variadic, + deep_isinstance, + deep_issubclass, + deep_type, + get_args, + get_origin, + get_type_hints, + typing_wrap, +) + + +def test_deep_issubclass_generic_identity(): + assert deep_issubclass(Reduce, Reduce) + assert deep_issubclass( + Reduce[AssociativeOp, Funsor, frozenset], + Reduce[AssociativeOp, Funsor, frozenset], + ) + assert deep_issubclass(Tuple, Tuple) + + +def test_deep_issubclass_generic_empty(): + assert deep_issubclass(Reduce[AssociativeOp, Funsor, frozenset], Funsor) + assert deep_issubclass(Reduce[AssociativeOp, Funsor, frozenset], Reduce) + assert not deep_issubclass(Funsor, Reduce[AssociativeOp, Funsor, frozenset]) + assert not deep_issubclass(Reduce, Reduce[AssociativeOp, Funsor, frozenset]) + + +def test_deep_issubclass_generic_neither(): + assert not deep_issubclass( + Reduce[AssociativeOp, Reduce[AssociativeOp, Funsor, frozenset], frozenset], + Reduce[Op, Variable, frozenset], + ) + assert not deep_issubclass( + Reduce[Op, Variable, frozenset], + Reduce[AssociativeOp, Reduce[AssociativeOp, Funsor, frozenset], frozenset], + ) + + assert not deep_issubclass( + Stack[str, Tuple[Number, Number]], + Stack[str, Tuple[Number, Reduce]], + ) + assert not deep_issubclass( + Stack[str, Tuple[Number, Reduce]], + Stack[str, Tuple[Number, Number]], + ) + + +def test_deep_issubclass_generic_tuple_internal(): + assert deep_issubclass(Stack[str, Tuple[Number, Number, Number]], Stack) + assert deep_issubclass(Stack[str, Tuple[Number, Number, Number]], Stack[str, tuple]) + assert not deep_issubclass(Stack, Stack[str, Tuple[Number, Number, Number]]) + assert not deep_issubclass( + Stack[str, tuple], Stack[str, Tuple[Number, Number, Number]] + ) + assert not deep_issubclass( + Stack[str, Tuple[Number, Number]], + Stack[str, Tuple[Number, Reduce]], + ) + + +def test_deep_issubclass_generic_union_internal(): + + assert deep_issubclass( + Reduce[AssociativeOp, Union[Number, Funsor], frozenset], Funsor + ) + assert not deep_issubclass( + Funsor, Reduce[AssociativeOp, Union[Number, Funsor], frozenset] + ) + + assert deep_issubclass( + Reduce[ + AssociativeOp, + Union[Number, Stack[str, Tuple[Number, Number]]], + frozenset, + ], + Funsor, + ) + assert deep_issubclass( + Reduce[AssociativeOp, Union[Number, Stack], frozenset], + Reduce[Op, Funsor, frozenset], + ) + + assert deep_issubclass( + Reduce[AssociativeOp, Union[Number, Stack], frozenset], + Reduce[AssociativeOp, Funsor, frozenset], + ) + assert not deep_issubclass( + Reduce[AssociativeOp, Funsor, frozenset], + Reduce[AssociativeOp, Union[Number, Stack], frozenset], + ) + assert not deep_issubclass( + Reduce[Op, Funsor, frozenset], + Reduce[AssociativeOp, Union[Number, Stack], frozenset], + ) + + +def test_deep_issubclass_generic_union_internal_multiple(): + assert not deep_issubclass( + Reduce[Union[Op, AssociativeOp], Stack, frozenset], + Reduce[ + AssociativeOp, + Union[ + Stack[str, tuple], + Reduce[AssociativeOp, Union[Cat, Stack], frozenset], + ], + frozenset, + ], + ) + + assert not deep_issubclass( + Reduce[ + AssociativeOp, + Union[Stack, Reduce[AssociativeOp, Union[Number, Stack], frozenset]], + frozenset, + ], + Reduce[Union[Op, AssociativeOp], Stack, frozenset], + ) + + +def test_deep_issubclass_tuple_variadic(): + + assert deep_issubclass(Tuple[int, ...], Tuple) + assert deep_issubclass(Tuple[int], Tuple[int, ...]) + assert deep_issubclass(Tuple[int, int], Tuple[int, ...]) + + assert not deep_issubclass(Tuple[int, ...], Tuple[int]) + assert not deep_issubclass(Tuple[int, ...], Tuple[int, int]) + + assert deep_issubclass(Tuple[Reduce, ...], Tuple[Funsor, ...]) + assert not deep_issubclass(Tuple[Funsor, ...], Tuple[Reduce, ...]) + + assert deep_issubclass(Tuple[Reduce], Tuple[Funsor, ...]) + assert not deep_issubclass(Tuple[Funsor], Tuple[Reduce, ...]) + + assert not deep_issubclass(Tuple[str], Tuple[int, ...]) + assert not deep_issubclass(Tuple[int, str], Tuple[int, ...]) + + assert deep_issubclass(Tuple[Tuple[int, str]], Tuple[Tuple, ...]) + assert deep_issubclass(Tuple[Tuple[int, str], Tuple[int, str]], Tuple[Tuple, ...]) + assert deep_issubclass(Tuple[Tuple[int, str]], Tuple[Tuple[int, str], ...]) + assert deep_issubclass(Tuple[Tuple[int, str], ...], Tuple[Tuple[int, str], ...]) + assert deep_issubclass( + Tuple[Tuple[int, str], Tuple[int, str]], Tuple[Tuple[int, str], ...] + ) + + +def test_deep_type_tuple(): + + x1 = (1, 1.5, "a") + expected_type1 = Tuple[int, float, str] + assert deep_type(x1) is expected_type1 + assert deep_isinstance(x1, expected_type1) + + x2 = (1, (2, 3)) + expected_type2 = Tuple[int, Tuple[int, int]] + assert deep_type(x2) is expected_type2 + assert deep_isinstance(x2, expected_type2) + + +def test_deep_type_frozenset(): + + x1 = frozenset(["a", "b"]) + expected_type1 = FrozenSet[str] + assert deep_type(x1) is expected_type1 + assert deep_isinstance(x1, expected_type1) + + with pytest.raises(NotImplementedError): + x2 = frozenset(["a", 1]) + deep_type(x2) + + +def test_generic_type_cons_hash(): + class A(metaclass=GenericTypeMeta): + pass + + class B(metaclass=GenericTypeMeta): + pass + + assert A[int] is A[int] + assert A[float] is not A[int] + assert B[int] is not A[int] + assert B[A[int], int] is B[A[int], int] + + assert FrozenSet[int] is FrozenSet[int] + assert FrozenSet[B[int]] is FrozenSet[B[int]] + + assert Tuple[B[int, int], ...] is Tuple[B[int, int], ...] + + assert Union[B[int]] is Union[B[int]] + assert Union[B[int], B[int]] is Union[B[int]] + + +def test_get_origin(): + + assert get_origin(Any) is Any + + assert get_origin(Tuple[int]) in (tuple, Tuple) + assert get_origin(FrozenSet[int]) in (frozenset, FrozenSet) + + assert get_origin(Union[int, int, str]) is Union + + assert get_origin(Reduce[AssociativeOp, Funsor, frozenset]) is Reduce + assert get_origin(Reduce) is Reduce + + +def test_get_args(): + assert not get_args(Any) + + assert get_args(Tuple[int]) == (int,) + assert get_args(Tuple[int, ...]) == (int, ...) + assert not get_args(Tuple) + + assert get_args(FrozenSet[int]) == (int,) + + assert int in get_args(Optional[int]) + + assert get_args(Union[int]) == () + assert get_args(Union[int, str]) == (int, str) + assert get_args(Union[int, int, str]) == (int, str) + + assert get_args(Reduce[AssociativeOp, Funsor, frozenset]) == ( + AssociativeOp, + Funsor, + frozenset, + ) + assert not get_args(Reduce) + + +def test_get_type_hints(): + def f(a: Tuple[int, ...], b: Reduce[AssociativeOp, Funsor, frozenset]) -> int: + return 0 + + hints = get_type_hints(f) + assert hints == { + "a": Tuple[int, ...], + "b": Reduce[AssociativeOp, Funsor, frozenset], + "return": int, + } + + hints.pop("return") + assert "return" in get_type_hints(f) + + +def test_variadic_dispatch_basic(): + @dispatch(Variadic[object]) + def f(*args): + return 1 + + @dispatch(int, int) + def f(a, b): + return 2 + + @dispatch(Variadic[int]) + def f(*args): + return 3 + + @dispatch(typing_wrap(Tuple), typing_wrap(Tuple)) + def f(a, b): + return 4 + + @dispatch(Variadic[Tuple]) + def f(*args): + return 5 + + assert f(1.5) == 1 + assert f(1.5, 1) == 1 + + assert f(1, 1) == 2 + assert f(1) == 3 + assert f(1, 2, 3) == 3 + + assert f((1, 1), (1, 1)) == 4 + assert f((1, 2)) == 5 + assert f((1, 2), (3, 4), (5, 6)) == 5 + + +def test_dispatch_typing(): + + f = PartialDispatcher(lambda *args: 1) + + @f.register() + def f2(a: int, b: int) -> int: + return 2 + + @f.register() + def f3(a: Tuple[int, int], b: Tuple[int, int]) -> int: + return 3 + + @f.register() + def f4(a: Tuple[int, ...], b: Tuple[int, int]) -> int: + return 4 + + @f.register() + def f5(a: Tuple[int, float], b: Tuple[int, float]) -> int: + return 5 + + assert f(1.5) == 1 + assert f(1.5, 1) == 1 + + assert f(1, 1) == 2 + + assert f((1, 1), (1, 1)) == 3 + assert f((1, 2)) == 1 + assert f((1, 2, 3), (4, 5)) == 4 + + assert f((1, 1.5), (2, 2.5)) == 5 + + +def test_variadic_dispatch_typing(): + + f = PartialDispatcher(lambda *args: 1) + + @f.register() + def _(a: int, b: int) -> int: + return 2 + + @f.register([int]) + def _(*args): + return 3 + + @f.register() + def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: + return 4 + + @f.register([Tuple[int, int]]) # list syntax for variadic + def _(*args): + return 5 + + @f.register() + def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: + return 6 + + assert f(1.5) == 1 + assert f(1.5, 1) == 1 + + assert f(1, 1) == 2 + assert f(1) == 3 + assert f(1, 2, 3) == 3 + + assert f((1, 1), (1, 1)) == 4 + assert f((1, 2)) == 5 + assert f((1, 2), (3, 4), (5, 6)) == 5 + + assert f((1, 1.5), (2, 2.5)) == 6