From 9d653cb0a93d90ddc802b0f9966ef9435b950dc9 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 18 Jan 2021 22:14:12 -0500 Subject: [PATCH 01/66] Extract parametric type metaclass from FunsorMeta and domains --- funsor/domains.py | 115 ++++++++++++---------------------------------- funsor/terms.py | 85 ++-------------------------------- funsor/util.py | 92 +++++++++++++++++++++++++++++++++++++ test/test_cnf.py | 2 +- 4 files changed, 126 insertions(+), 168 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 87f2ce644..22568700f 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -6,20 +6,19 @@ import operator import warnings from functools import reduce -from weakref import WeakValueDictionary import funsor.ops as ops -from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote +from funsor.util import GenericTypeMeta, broadcast_shape, get_backend, get_tracing_state, quote -Domain = type + +class Domain(GenericTypeMeta): + pass class ArrayType(Domain): """ Base class of array-like domains. 
""" - _type_cache = WeakValueDictionary() - def __getitem__(cls, dtype_shape): dtype, shape = dtype_shape assert dtype is not None @@ -32,23 +31,7 @@ def __getitem__(cls, dtype_shape): if shape is not None: shape = tuple(map(int, shape)) - assert cls.dtype in (None, dtype) - assert cls.shape in (None, shape) - key = dtype, shape - result = ArrayType._type_cache.get(key, None) - if result is None: - if dtype == "real": - assert all(isinstance(size, int) and size >= 0 for size in shape) - name = "Reals[{}]".format(",".join(map(str, shape))) if shape else "Real" - result = RealsType(name, (), {"shape": shape}) - elif isinstance(dtype, int): - assert dtype >= 0 - name = "Bint[{}, {}]".format(dtype, ",".join(map(str, shape))) - result = BintType(name, (), {"dtype": dtype, "shape": shape}) - else: - raise ValueError("invalid dtype: {}".format(dtype)) - ArrayType._type_cache[key] = result - return result + return super().__getitem__((dtype, shape)) def __subclasscheck__(cls, subcls): if not isinstance(subcls, ArrayType): @@ -59,34 +42,18 @@ def __subclasscheck__(cls, subcls): return False return True - def __repr__(cls): - return cls.__name__ + @property + def dtype(cls): + return cls.__args__[0] - def __str__(cls): - return cls.__name__ + @property + def shape(cls): + return cls.__args__[1] @property def num_elements(cls): return reduce(operator.mul, cls.shape, 1) - -class BintType(ArrayType): - def __getitem__(cls, size_shape): - if isinstance(size_shape, tuple): - size, shape = size_shape[0], size_shape[1:] - else: - size, shape = size_shape, () - return super().__getitem__((size, shape)) - - def __subclasscheck__(cls, subcls): - if not isinstance(subcls, BintType): - return False - if cls.dtype not in (None, subcls.dtype): - return False - if cls.shape not in (None, subcls.shape): - return False - return True - @property def size(cls): return cls.dtype @@ -96,27 +63,25 @@ def __iter__(cls): return (Number(i, cls.size) for i in range(cls.size)) -class 
RealsType(ArrayType): - dtype = "real" +class BintType(ArrayType): + def __getitem__(cls, size_shape): + if isinstance(size_shape, tuple): + size, shape = size_shape[0], size_shape[1:] + else: + size, shape = size_shape, () + return Array.__getitem__((size, shape)) + +class RealsType(ArrayType): def __getitem__(cls, shape): if not isinstance(shape, tuple): shape = (shape,) - return super().__getitem__(("real", shape)) - - def __subclasscheck__(cls, subcls): - if not isinstance(subcls, RealsType): - return False - if cls.dtype not in (None, subcls.dtype): - return False - if cls.shape not in (None, subcls.shape): - return False - return True + return Array.__getitem__(("real", shape)) def _pickle_array(cls): - if cls in (Array, Bint, Real, Reals): - return cls.__name__ + if cls in (Array, Bint, Reals): + return repr(cls) return operator.getitem, (Array, (cls.dtype, cls.shape)) @@ -132,22 +97,20 @@ class Array(metaclass=ArrayType): Arary["real", (3, 3)] = Reals[3, 3] Array["real", ()] = Real """ - dtype = None - shape = None + pass -class Bint(metaclass=BintType): +class Bint(Array, metaclass=BintType): """ Factory for bounded integer types:: Bint[5] # integers ranging in {0,1,2,3,4} Bint[2, 3, 3] # 3x3 matrices with entries in {0,1} """ - dtype = None - shape = None + pass -class Reals(metaclass=RealsType): +class Reals(Array, metaclass=RealsType): """ Type of a real-valued array with known shape:: @@ -155,7 +118,7 @@ class Reals(metaclass=RealsType): Reals[8] # vector of length 8 Reals[3, 3] # 3x3 matrix """ - shape = None + pass Real = Reals[()] @@ -176,26 +139,6 @@ def bint(size): class ProductDomain(Domain): - - _type_cache = WeakValueDictionary() - - def __getitem__(cls, arg_domains): - try: - return ProductDomain._type_cache[arg_domains] - except KeyError: - assert isinstance(arg_domains, tuple) - assert all(isinstance(arg_domain, Domain) for arg_domain in arg_domains) - subcls = type("Product_", (Product,), {"__args__": arg_domains}) - 
ProductDomain._type_cache[arg_domains] = subcls - return subcls - - def __repr__(cls): - return "Product[{}]".format(", ".join(map(repr, cls.__args__))) - - @property - def __origin__(cls): - return Product - @property def shape(cls): return (len(cls.__args__),) @@ -203,7 +146,7 @@ def shape(cls): class Product(tuple, metaclass=ProductDomain): """like typing.Tuple, but works with issubclass""" - __args__ = NotImplemented + pass @quote.register(BintType) diff --git a/funsor/terms.py b/funsor/terms.py index 7cf6d1f51..45007c2a2 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -13,14 +13,14 @@ from weakref import WeakValueDictionary from multipledispatch import dispatch -from multipledispatch.variadic import Variadic, isvariadic +from multipledispatch.variadic import Variadic import funsor.interpreter as interpreter import funsor.ops as ops from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.util import getargspec, get_backend, lazy_property, pretty, quote +from funsor.util import GenericTypeMeta, getargspec, get_backend, lazy_property, pretty, quote def substitute(expr, subs): @@ -182,7 +182,7 @@ def moment_matching(cls, *args): interpreter.set_interpretation(eager) # Use eager interpretation by default. 
-class FunsorMeta(type): +class FunsorMeta(GenericTypeMeta): """ Metaclass for Funsors to perform four independent tasks: @@ -206,15 +206,9 @@ class FunsorMeta(type): """ def __init__(cls, name, bases, dct): super(FunsorMeta, cls).__init__(name, bases, dct) - if not hasattr(cls, "__args__"): - cls.__args__ = () - if cls.__args__: - base, = bases - cls.__origin__ = base - else: + if not cls.__args__: cls._ast_fields = getargspec(cls.__init__)[0][1:] cls._cons_cache = WeakValueDictionary() - cls._type_cache = WeakValueDictionary() def __call__(cls, *args, **kwargs): if cls.__args__: @@ -230,77 +224,6 @@ def __call__(cls, *args, **kwargs): return interpret(cls, *args) - def __getitem__(cls, arg_types): - if not isinstance(arg_types, tuple): - arg_types = (arg_types,) - assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" - # switch tuple to typing.Tuple - arg_types = tuple(typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types) - if arg_types not in cls._type_cache: - assert not cls.__args__, "cannot subscript a subscripted type {}".format(cls) - assert len(arg_types) == len(cls._ast_fields), "must provide types for all params" - new_dct = cls.__dict__.copy() - new_dct.update({"__args__": arg_types}) - # type(cls) to handle FunsorMeta subclasses - cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) - return cls._type_cache[arg_types] - - def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) - if cls is subcls: - return True - if not isinstance(subcls, FunsorMeta): - return super(FunsorMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(subcls) - - cls_origin = getattr(cls, "__origin__", cls) - subcls_origin = getattr(subcls, "__origin__", subcls) - if not super(FunsorMeta, cls_origin).__subclasscheck__(subcls_origin): - return False - - if cls.__args__: - if not subcls.__args__: - return False - if len(cls.__args__) != len(subcls.__args__): - return False - for 
subcls_param, param in zip(subcls.__args__, cls.__args__): - if not _issubclass_tuple(subcls_param, param): - return False - return True - - @lazy_property - def classname(cls): - return cls.__name__ + "[{}]".format(", ".join( - str(getattr(t, "classname", t)) # Tuple doesn't have __name__ - for t in cls.__args__)) - - -def _issubclass_tuple(subcls, cls): - """ - utility for pattern matching with tuple subexpressions - """ - # so much boilerplate... - cls_is_union = hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union - if isinstance(cls, tuple) or cls_is_union: - return any(_issubclass_tuple(subcls, option) - for option in (getattr(cls, "__args__", []) if cls_is_union else cls)) - - subcls_is_union = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union - if isinstance(subcls, tuple) or subcls_is_union: - return any(_issubclass_tuple(option, cls) - for option in (getattr(subcls, "__args__", []) if subcls_is_union else subcls)) - - subcls_is_tuple = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) in (tuple, typing.Tuple) - cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (tuple, typing.Tuple) - if subcls_is_tuple != cls_is_tuple: - return False - if not cls_is_tuple: - return issubclass(subcls, cls) - if not cls.__args__: - return True - if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__): - return False - - return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__)) - def _convert_reduced_vars(reduced_vars, inputs): """ diff --git a/funsor/util.py b/funsor/util.py index d0692ebd8..004f5d596 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -4,8 +4,12 @@ import functools import inspect import re +import typing +import weakref import numpy as np +from multipledispatch.variadic import isvariadic + _FUNSOR_BACKEND = "numpy" _JAX_LOADED = False @@ -230,3 +234,91 @@ def decorator(fn): setattr(cls, name_, fn) return fn return decorator + + +class 
GenericTypeMeta(type): + """ + Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. + """ + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + if not hasattr(cls, "__args__"): + cls.__args__ = () + if cls.__args__: + base, = bases + cls.__origin__ = base + else: + cls._type_cache = weakref.WeakValueDictionary() + + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" + # switch tuple to typing.Tuple + arg_types = tuple(typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types) + if arg_types not in cls._type_cache: + assert not cls.__args__, "cannot subscript a subscripted type {}".format(cls) + new_dct = cls.__dict__.copy() + new_dct.update({"__args__": arg_types}) + # type(cls) to handle GenericTypeMeta subclasses + cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) + return cls._type_cache[arg_types] + + def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) + if cls is subcls: + return True + if not isinstance(subcls, GenericTypeMeta): + return super(GenericTypeMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(subcls) + + cls_origin = getattr(cls, "__origin__", cls) + subcls_origin = getattr(subcls, "__origin__", subcls) + if not super(GenericTypeMeta, cls_origin).__subclasscheck__(subcls_origin): + return False + + if cls.__args__: + if not subcls.__args__: + return False + if len(cls.__args__) != len(subcls.__args__): + return False + for subcls_param, param in zip(subcls.__args__, cls.__args__): + if not _issubclass_tuple(subcls_param, param): + return False + return True + + def __repr__(cls): + return cls.__name__ + ( + "" if not cls.__args__ else + "[{}]".format(", ".join(repr(t) for t in cls.__args__))) + + @lazy_property + def classname(cls): + return repr(cls) + + +def _issubclass_tuple(subcls, 
cls): + """ + utility for structural subtype checking with tuple subexpressions + """ + # so much boilerplate... + cls_is_union = hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union + if isinstance(cls, tuple) or cls_is_union: + return any(_issubclass_tuple(subcls, option) + for option in (getattr(cls, "__args__", []) if cls_is_union else cls)) + + subcls_is_union = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union + if isinstance(subcls, tuple) or subcls_is_union: + return any(_issubclass_tuple(option, cls) + for option in (getattr(subcls, "__args__", []) if subcls_is_union else subcls)) + + subcls_is_tuple = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) in (tuple, typing.Tuple) + cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (tuple, typing.Tuple) + if subcls_is_tuple != cls_is_tuple: + return False + if not cls_is_tuple: + return issubclass(subcls, cls) + if not cls.__args__: + return True + if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__): + return False + + return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__)) diff --git a/test/test_cnf.py b/test/test_cnf.py index 2e1c6494d..a5e184955 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -9,7 +9,7 @@ from funsor import ops from funsor.cnf import Contraction, BACKEND_TO_EINSUM_BACKEND, BACKEND_TO_LOGSUMEXP_BACKEND -from funsor.domains import Bint, Bint # noqa F403 +from funsor.domains import Array, Bint # noqa F403 from funsor.domains import Reals from funsor.einsum import einsum, naive_plated_einsum from funsor.interpreter import interpretation, reinterpret From 0177b072faadd9822a38651c602256dec087d439 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 19 Jan 2021 00:23:20 -0500 Subject: [PATCH 02/66] restore assertion in FunsorMeta.__getitem__ --- funsor/terms.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/funsor/terms.py b/funsor/terms.py index 
45007c2a2..019b1d7f5 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -210,6 +210,13 @@ def __init__(cls, name, bases, dct): cls._ast_fields = getargspec(cls.__init__)[0][1:] cls._cons_cache = WeakValueDictionary() + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + assert len(arg_types) == len(cls._ast_fields), \ + "Must provide exactly one type per subexpression" + return super().__getitem__(arg_types) + def __call__(cls, *args, **kwargs): if cls.__args__: cls = cls.__origin__ From 14622777769b0b71ab3eadd80745f0d4e02d9b25 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 19 Jan 2021 18:44:41 -0500 Subject: [PATCH 03/66] switch to pytypes multipledispatch backend and use pytypes in generic type --- funsor/ops/op.py | 7 +++- funsor/registry.py | 9 ++++- funsor/terms.py | 3 +- funsor/util.py | 82 +++++++++++++++++++--------------------------- 4 files changed, 50 insertions(+), 51 deletions(-) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 700c70eae..fe0a9a16f 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -7,6 +7,8 @@ from multipledispatch import Dispatcher +from funsor.util import _type_to_typing + class CachedOpMeta(type): """ @@ -39,7 +41,10 @@ def __init__(self, fn, *, name=None): # register as default operation for nargs in (1, 2): default_signature = (object,) * nargs - self.add(default_signature, fn) + self.register(*default_signature)(fn) + + def register(self, *types): + return super().register(*tuple(map(_type_to_typing, types))) # Register all existing patterns. for supercls in reversed(inspect.getmro(type(self))): diff --git a/funsor/registry.py b/funsor/registry.py index a46d7052b..9401ceabf 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -1,11 +1,16 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import functools +import typing from collections import defaultdict +import pytypes from multipledispatch import Dispatcher from multipledispatch.conflict import supercedes +from funsor.util import _type_to_typing + class PartialDispatcher(Dispatcher): """ @@ -16,6 +21,7 @@ def partial_call(self, *args): Likde :meth:`__call__` but avoids calling ``func()``. """ types = tuple(map(type, args)) + types = tuple(map(_type_to_typing, types)) try: func = self._cache[types] except KeyError: @@ -47,10 +53,11 @@ def __init__(self, default=None): self.registry = defaultdict(lambda: PartialDispatcher('f')) def register(self, key, *types): + types = tuple(map(_type_to_typing, types)) key = getattr(key, "__origin__", key) register = self.registry[key].register if self.default: - objects = (object,) * len(types) + objects = (typing.Any,) * len(types) try: if objects != types and supercedes(types, objects): register(*objects)(self.default) diff --git a/funsor/terms.py b/funsor/terms.py index 019b1d7f5..33588ea77 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -20,7 +20,7 @@ from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.util import GenericTypeMeta, getargspec, get_backend, lazy_property, pretty, quote +from funsor.util import GenericTypeMeta, getargspec, get_backend, lazy_property, pretty, quote, _type_to_typing def substitute(expr, subs): @@ -215,6 +215,7 @@ def __getitem__(cls, arg_types): arg_types = (arg_types,) assert len(arg_types) == len(cls._ast_fields), \ "Must provide exactly one type per subexpression" + arg_types = tuple(map(_type_to_typing, arg_types)) return super().__getitem__(arg_types) def __call__(cls, *args, **kwargs): diff --git a/funsor/util.py b/funsor/util.py index 004f5d596..cd5f60eef 100644 --- a/funsor/util.py +++ b/funsor/util.py 
@@ -5,8 +5,10 @@ import inspect import re import typing +import typing_extensions import weakref +import pytypes import numpy as np from multipledispatch.variadic import isvariadic @@ -236,6 +238,26 @@ def decorator(fn): return decorator +def _type_to_typing(tp): + if tp is object: + tp = typing.Any + if isinstance(tp, tuple): + tp = typing.Union[tuple(map(_type_to_typing, tp))] + return tp + + +def get_origin(tp): + if isinstance(tp, GenericTypeMeta): + return getattr(tp, "__origin__", tp) + return typing_extensions.get_origin(tp) + + +def get_args(tp): + if isinstance(tp, GenericTypeMeta): + return getattr(tp, "__args__", tp) + return typing_extensions.get_args(tp) + + class GenericTypeMeta(type): """ Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. @@ -254,10 +276,8 @@ def __getitem__(cls, arg_types): if not isinstance(arg_types, tuple): arg_types = (arg_types,) assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" - # switch tuple to typing.Tuple - arg_types = tuple(typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types) if arg_types not in cls._type_cache: - assert not cls.__args__, "cannot subscript a subscripted type {}".format(cls) + assert not get_args(cls), "cannot subscript a subscripted type {}".format(cls) new_dct = cls.__dict__.copy() new_dct.update({"__args__": arg_types}) # type(cls) to handle GenericTypeMeta subclasses @@ -267,58 +287,24 @@ def __getitem__(cls, arg_types): def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) if cls is subcls: return True + if not isinstance(subcls, GenericTypeMeta): - return super(GenericTypeMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(subcls) + return super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(subcls) - cls_origin = getattr(cls, "__origin__", cls) - subcls_origin = getattr(subcls, "__origin__", subcls) - if not super(GenericTypeMeta, 
cls_origin).__subclasscheck__(subcls_origin): + if not super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(get_origin(subcls)): return False - if cls.__args__: - if not subcls.__args__: - return False - if len(cls.__args__) != len(subcls.__args__): - return False - for subcls_param, param in zip(subcls.__args__, cls.__args__): - if not _issubclass_tuple(subcls_param, param): - return False - return True + if len(get_args(cls)) != len(get_args(subcls)): + return len(get_args(cls)) == 0 + + return all(pytypes.is_subtype(_type_to_typing(ps), _type_to_typing(pc)) + for ps, pc in zip(get_args(subcls), get_args(cls))) def __repr__(cls): - return cls.__name__ + ( - "" if not cls.__args__ else - "[{}]".format(", ".join(repr(t) for t in cls.__args__))) + return get_origin(cls).__name__ + ( + "" if not get_args(cls) else + "[{}]".format(", ".join(repr(t) for t in get_args(cls)))) @lazy_property def classname(cls): return repr(cls) - - -def _issubclass_tuple(subcls, cls): - """ - utility for structural subtype checking with tuple subexpressions - """ - # so much boilerplate... 
- cls_is_union = hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union - if isinstance(cls, tuple) or cls_is_union: - return any(_issubclass_tuple(subcls, option) - for option in (getattr(cls, "__args__", []) if cls_is_union else cls)) - - subcls_is_union = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union - if isinstance(subcls, tuple) or subcls_is_union: - return any(_issubclass_tuple(option, cls) - for option in (getattr(subcls, "__args__", []) if subcls_is_union else subcls)) - - subcls_is_tuple = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) in (tuple, typing.Tuple) - cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (tuple, typing.Tuple) - if subcls_is_tuple != cls_is_tuple: - return False - if not cls_is_tuple: - return issubclass(subcls, cls) - if not cls.__args__: - return True - if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__): - return False - - return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__)) From 2d6ee23eb4c1ac198810e2f4531989f352273ae3 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 19 Jan 2021 22:46:06 -0500 Subject: [PATCH 04/66] pin dependencies using new pip syntax --- setup.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/setup.py b/setup.py index e8148b577..d59520598 100644 --- a/setup.py +++ b/setup.py @@ -33,11 +33,18 @@ author='Uber AI Labs', author_email='fritzo@uber.com', python_requires=">=3.6", + dependency_links=[ + # pin pytypes to master + 'git+https://github.com/Stewori/pytypes.git@master#egg=pytypes-0a0', + # use a fork of multipledispatch that depends on pytypes + 'git+https://github.com/eb8680/multipledispatch.git@pytypes-master#egg=multipledispatch.git-0.6.0a0', + ], install_requires=[ 'makefun', 'multipledispatch', 'numpy>=1.7', 'opt_einsum>=2.3.2', + 'pytypes', ], extras_require={ 'torch': [ From 5b98eba1e722195881839960c1044dec4bc32db8 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 
20 Jan 2021 00:25:09 -0500 Subject: [PATCH 05/66] fix dependencies --- setup.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index d59520598..cdf3c0ad7 100644 --- a/setup.py +++ b/setup.py @@ -33,18 +33,14 @@ author='Uber AI Labs', author_email='fritzo@uber.com', python_requires=">=3.6", - dependency_links=[ - # pin pytypes to master - 'git+https://github.com/Stewori/pytypes.git@master#egg=pytypes-0a0', - # use a fork of multipledispatch that depends on pytypes - 'git+https://github.com/eb8680/multipledispatch.git@pytypes-master#egg=multipledispatch.git-0.6.0a0', - ], install_requires=[ 'makefun', - 'multipledispatch', + # use a fork of multipledispatch that depends on pytypes + 'git+https://github.com/eb8680/multipledispatch.git@pytypes-master#egg=multipledispatch', 'numpy>=1.7', 'opt_einsum>=2.3.2', - 'pytypes', + # pin pytypes to master + 'git+https://github.com/Stewori/pytypes.git@master#egg=pytypes', ], extras_require={ 'torch': [ From 664f3ed72c208766183452e3d8ab5995e69116c4 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Jan 2021 22:36:18 -0500 Subject: [PATCH 06/66] attempt to work around need to change multipledispatch and revert changes to op --- funsor/ops/op.py | 7 +------ funsor/registry.py | 36 ++++++++++++++++++++++-------------- funsor/util.py | 13 +++++++++++++ setup.py | 3 +-- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index fe0a9a16f..700c70eae 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -7,8 +7,6 @@ from multipledispatch import Dispatcher -from funsor.util import _type_to_typing - class CachedOpMeta(type): """ @@ -41,10 +39,7 @@ def __init__(self, fn, *, name=None): # register as default operation for nargs in (1, 2): default_signature = (object,) * nargs - self.register(*default_signature)(fn) - - def register(self, *types): - return super().register(*tuple(map(_type_to_typing, types))) + self.add(default_signature, fn) # 
Register all existing patterns. for supercls in reversed(inspect.getmro(type(self))): diff --git a/funsor/registry.py b/funsor/registry.py index 9401ceabf..e560b9ddf 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -9,13 +9,17 @@ from multipledispatch import Dispatcher from multipledispatch.conflict import supercedes -from funsor.util import _type_to_typing +from funsor.util import _type_to_typing, get_origin, typing_wrap class PartialDispatcher(Dispatcher): """ Wrapper to avoid appearance in stack traces. """ + def __init__(self, name, default=None): + self.default = default if default is None else PartialDefault(default) + super().__init__(name) + def partial_call(self, *args): """ Likde :meth:`__call__` but avoids calling ``func()``. @@ -33,6 +37,14 @@ def partial_call(self, *args): self._cache[types] = func return func + def register(self, *types): + types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) + if self.default: + objects = (typing_wrap[typing.Any],) * len(types) + if objects != types and safe_supercedes(types, objects): + super().register(*objects)(self.default) + return super().register(*types) + class PartialDefault: def __init__(self, default): @@ -46,23 +58,19 @@ def partial_call(self, *args): return self.default +def safe_supercedes(xs, ys): + return supercedes(tuple(typing_wrap[_type_to_typing(x)] for x in xs), + tuple(typing_wrap[_type_to_typing(y)] for y in ys)) + + class KeyedRegistry(object): def __init__(self, default=None): - self.default = default if default is None else PartialDefault(default) - self.registry = defaultdict(lambda: PartialDispatcher('f')) + self.registry = defaultdict(lambda: PartialDispatcher('f', default=default)) + self.default = PartialDefault(default) if default is not None else default def register(self, key, *types): - types = tuple(map(_type_to_typing, types)) - key = getattr(key, "__origin__", key) - register = self.registry[key].register - if self.default: - objects = (typing.Any,) * 
len(types) - try: - if objects != types and supercedes(types, objects): - register(*objects)(self.default) - except TypeError: - pass # mysterious source of ambiguity in Python 3.5 breaks this + register = self.registry[get_origin(key)].register # This decorator supports stacking multiple decorators, which is not # supported by multipledipatch (which returns a Dispatch object rather @@ -77,7 +85,7 @@ def __contains__(self, key): return key in self.registry def __getitem__(self, key): - key = getattr(key, "__origin__", key) + key = get_origin(key) if self.default is None: return self.registry[key] return self.registry.get(key, self.default) diff --git a/funsor/util.py b/funsor/util.py index cd5f60eef..ffc03c667 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -284,6 +284,7 @@ def __getitem__(cls, arg_types): cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) return cls._type_cache[arg_types] + @functools.lru_cache(maxsize=None) def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) if cls is subcls: return True @@ -308,3 +309,15 @@ def __repr__(cls): @lazy_property def classname(cls): return repr(cls) + + +class _PytypesSubclasser(GenericTypeMeta): + def __getitem__(cls, tp): + return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else super().__getitem__(tp) + + def __subclasscheck__(cls, subcls): + return pytypes.is_subtype(subcls, cls.__args__[0]) + + +class typing_wrap(metaclass=_PytypesSubclasser): + pass diff --git a/setup.py b/setup.py index cdf3c0ad7..430108dfa 100644 --- a/setup.py +++ b/setup.py @@ -35,8 +35,7 @@ python_requires=">=3.6", install_requires=[ 'makefun', - # use a fork of multipledispatch that depends on pytypes - 'git+https://github.com/eb8680/multipledispatch.git@pytypes-master#egg=multipledispatch', + 'multipledispatch', 'numpy>=1.7', 'opt_einsum>=2.3.2', # pin pytypes to master From d850d07a9d8fe0e52d68c4126bf4673825f15ab0 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 03:02:53 -0500 
Subject: [PATCH 07/66] refactor Contraction to non-variadic patterns and remove all uses of variadic --- funsor/cnf.py | 80 +++++++++++++++++------------------ funsor/delta.py | 6 ++- funsor/distribution.py | 15 ++++--- funsor/gaussian.py | 4 +- funsor/integrate.py | 9 ++-- funsor/interpreter.py | 1 + funsor/jax/distributions.py | 10 ++--- funsor/joint.py | 21 +++------ funsor/optimizer.py | 20 +++------ funsor/registry.py | 4 +- funsor/tensor.py | 15 ++++--- funsor/terms.py | 26 +++--------- funsor/torch/distributions.py | 10 ++--- funsor/util.py | 6 +++ 14 files changed, 106 insertions(+), 121 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 71a9cc59c..dfddd0099 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -8,7 +8,6 @@ from typing import Tuple, Union import opt_einsum -from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -16,13 +15,14 @@ from funsor.delta import Delta from funsor.domains import find_domain from funsor.gaussian import Gaussian -from funsor.interpreter import interpretation, recursion_reinterpret +from funsor.interpreter import interpretation from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop from funsor.tensor import Tensor from funsor.terms import ( Align, Binary, Funsor, + FunsorMeta, Number, Reduce, Subs, @@ -36,7 +36,15 @@ from funsor.util import broadcast_shape, get_backend, quote -class Contraction(Funsor): +class ContractionMeta(FunsorMeta): + + def __call__(self, red_op, bin_op, reduced_vars, *terms): + if len(terms) == 1 and isinstance(terms[0], tuple): + terms, = terms + return super().__call__(red_op, bin_op, reduced_vars, tuple(terms)) + + +class Contraction(Funsor, metaclass=ContractionMeta): """ Declarative representation of a finitary sum-product operation. 
@@ -177,17 +185,7 @@ def _(arg, indent, out): out[-1] = i, line + ")" -@recursion_reinterpret.register(Contraction) -def recursion_reinterpret_contraction(x): - return type(x)(*map(recursion_reinterpret, (x.red_op, x.bin_op, x.reduced_vars) + x.terms)) - - -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor]) -def eager_contraction_generic_to_tuple(red_op, bin_op, reduced_vars, *terms): - return eager(Contraction, red_op, bin_op, reduced_vars, terms) - - -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms): # Count the number of terms in which each variable is reduced. counts = Counter() @@ -228,15 +226,16 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms): return None -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor) -def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term): +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor]) +def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, terms): + term, = terms args = red_op, term, reduced_vars return eager.dispatch(Reduce, *args)(*args) -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor) -def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs): - +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, Funsor]) +def eager_contraction_to_binary(red_op, bin_op, reduced_vars, terms): + lhs, rhs = terms if not reduced_vars.issubset(lhs.input_vars & rhs.input_vars): args = red_op, bin_op, reduced_vars, (lhs, rhs) result = eager.dispatch(Contraction, *args)(*args) @@ -251,16 +250,16 @@ def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs): return result -@eager.register(Contraction, ops.AddOp, 
ops.MulOp, frozenset, Tensor, Tensor) -def eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms): +@eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset, Tuple[Tensor, Tensor]) +def eager_contraction_tensor(red_op, bin_op, reduced_vars, terms): if not all(term.dtype == "real" for term in terms): raise NotImplementedError('TODO') backend = BACKEND_TO_EINSUM_BACKEND[get_backend()] return _eager_contract_tensors(reduced_vars, terms, backend=backend) -@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tensor, Tensor) -def eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms): +@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Tensor, Tensor]) +def eager_contraction_tensor(red_op, bin_op, reduced_vars, terms): if not all(term.dtype == "real" for term in terms): raise NotImplementedError('TODO') backend = BACKEND_TO_LOGSUMEXP_BACKEND[get_backend()] @@ -307,8 +306,9 @@ def _eager_contract_tensors(reduced_vars, terms, backend): # Pyro's gaussian_tensordot() here. Until then we must eagerly add the # possibly-rank-deficient terms before reducing to avoid Cholesky errors. 
@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, - GaussianMixture, GaussianMixture) -def eager_contraction_gaussian(red_op, bin_op, reduced_vars, x, y): + Tuple[GaussianMixture, GaussianMixture]) +def eager_contraction_gaussian(red_op, bin_op, reduced_vars, terms): + x, y = terms return (x + y).reduce(red_op, reduced_vars) @@ -324,11 +324,11 @@ def _(fn): ########################################## ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4} -GROUND_TERMS = tuple(ORDERING) +GROUND_TERMS = Union[Delta, Number, Tensor, Gaussian] -@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GROUND_TERMS) -def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, *terms): +@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, Tuple[GROUND_TERMS, GROUND_TERMS]) +def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, terms): # when bin_op is commutative, put terms into a canonical order for pattern matching new_terms = tuple( v for i, v in sorted(enumerate(terms), @@ -336,33 +336,31 @@ def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_va ) if any(v is not vv for v, vv in zip(terms, new_terms)): return Contraction(red_op, bin_op, reduced_vars, *new_terms) - return normalize(Contraction, red_op, bin_op, reduced_vars, new_terms) + return None # normalize(Contraction, red_op, bin_op, reduced_vars, new_terms) -@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GaussianMixture, GROUND_TERMS) -def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other): +@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, Tuple[GaussianMixture, GROUND_TERMS]) +def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, terms): + mixture, other = terms return Contraction(mixture.red_op if red_op is nullop else red_op, bin_op, reduced_vars | mixture.reduced_vars, 
*(mixture.terms + (other,))) -@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GaussianMixture) -def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture): +@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, Tuple[GROUND_TERMS, GaussianMixture]) +def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, terms): + other, mixture = terms return Contraction(mixture.red_op if red_op is nullop else red_op, bin_op, reduced_vars | mixture.reduced_vars, *(mixture.terms + (other,))) -@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor]) -def normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms): - return normalize(Contraction, red_op, bin_op, reduced_vars, tuple(terms)) - - -@normalize.register(Contraction, NullOp, NullOp, frozenset, Funsor) -def normalize_trivial(red_op, bin_op, reduced_vars, term): +@normalize.register(Contraction, NullOp, NullOp, frozenset, Tuple[Funsor]) +def normalize_trivial(red_op, bin_op, reduced_vars, terms): + term, = terms assert not reduced_vars return term -@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) +@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): if not reduced_vars and red_op is not nullop: diff --git a/funsor/delta.py b/funsor/delta.py index 1a37d3724..eec0f0339 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -187,7 +187,8 @@ def eager_add_multidelta(op, lhs, rhs): return Delta(lhs.terms + rhs.terms) -@eager.register(Binary, (AddOp, SubOp), Delta, (Funsor, Align)) +@eager.register(Binary, (AddOp, SubOp), Delta, Align) +@eager.register(Binary, (AddOp, SubOp), Delta, Funsor) def eager_add_delta_funsor(op, lhs, rhs): if lhs.fresh.intersection(rhs.inputs): rhs = rhs(**{name: point for name, (point, log_density) in lhs.terms if name 
in rhs.inputs}) @@ -196,7 +197,8 @@ def eager_add_delta_funsor(op, lhs, rhs): return None # defer to default implementation -@eager.register(Binary, AddOp, (Funsor, Align), Delta) +@eager.register(Binary, AddOp, Align, Delta) +@eager.register(Binary, AddOp, Funsor, Delta) def eager_add_funsor_delta(op, lhs, rhs): if rhs.fresh.intersection(lhs.inputs): lhs = lhs(**{name: point for name, (point, log_density) in rhs.terms if name in lhs.inputs}) diff --git a/funsor/distribution.py b/funsor/distribution.py index 81af49460..3789c806a 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -718,13 +718,15 @@ def eager_mvn(loc, scale_tril, value): return gaussian(**{var: value - loc}) -def eager_beta_bernoulli(red_op, bin_op, reduced_vars, x, y): +def eager_beta_bernoulli(red_op, bin_op, reduced_vars, xy): + x, y = xy backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, backend_dist.Binomial(total_count=1, probs=y.probs, value=y.value)) -def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y): +def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, xy): + x, y = xy dirichlet_reduction = x.input_vars & reduced_vars if dirichlet_reduction: backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) @@ -736,7 +738,8 @@ def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y): return eager(Contraction, red_op, bin_op, reduced_vars, (x, y)) -def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y): +def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, xy): + x, y = xy dirichlet_reduction = x.input_vars & reduced_vars if dirichlet_reduction: backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) @@ -769,7 +772,8 @@ def _log_beta(x, y): return ops.lgamma(x) + ops.lgamma(y) - ops.lgamma(x + y) -def eager_gamma_gamma(red_op, bin_op, reduced_vars, x, y): +def eager_gamma_gamma(red_op, 
bin_op, reduced_vars, xy): + x, y = xy gamma_reduction = x.input_vars & reduced_vars if gamma_reduction: unnormalized = (y.concentration - 1) * ops.log(y.value) \ @@ -780,7 +784,8 @@ def eager_gamma_gamma(red_op, bin_op, reduced_vars, x, y): return eager(Contraction, red_op, bin_op, reduced_vars, (x, y)) -def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y): +def eager_gamma_poisson(red_op, bin_op, reduced_vars, xy): + x, y = xy gamma_reduction = x.input_vars & reduced_vars if gamma_reduction: backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 8c74cf0aa..f7a53ed08 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -645,8 +645,8 @@ def eager_add_gaussian_gaussian(op, lhs, rhs): return Gaussian(info_vec, precision, inputs) -@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian)) -@eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian) +# @eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian)) +# @eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian) def eager_sub(op, lhs, rhs): return lhs + -rhs diff --git a/funsor/integrate.py b/funsor/integrate.py index 3dce6a5f9..7c9623d10 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict -from typing import Union +from typing import Tuple, Union import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture @@ -91,9 +91,10 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): @eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset, - Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]], - (Variable, Delta, Gaussian, Number, Tensor, GaussianMixture)) -def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs): + Tuple[Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]], + 
Union[Variable, Delta, Gaussian, Number, Tensor, GaussianMixture]]) +def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, terms): + lhs, rhs = terms reduced_names = frozenset(v.name for v in reduced_vars) if not (reduced_names.issubset(lhs.inputs) and reduced_names.issubset(rhs.inputs)): args = red_op, bin_op, reduced_vars, (lhs, rhs) diff --git a/funsor/interpreter.py b/funsor/interpreter.py index bc9e5b5d6..f06810037 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -330,6 +330,7 @@ def dispatched_interpretation(fn): else: fn.register = registry.register fn.dispatch = registry.dispatch + fn._registry = registry return fn diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 15f685304..94e37afb2 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -278,15 +278,15 @@ def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None): eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable) # noqa: F821 eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal) # noqa: F821 eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn) # noqa: F821 -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, BernoulliProbs)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, BernoulliProbs])( # noqa: F821 eager_beta_bernoulli) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Categorical)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Categorical])( # noqa: F821 eager_dirichlet_categorical) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Multinomial)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Multinomial])( # noqa: F821 eager_dirichlet_multinomial) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, 
Gamma)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Gamma])( # noqa: F821 eager_gamma_gamma) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Poisson)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Poisson])( # noqa: F821 eager_gamma_poisson) if hasattr(dist, "DirichletMultinomial"): eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)( # noqa: F821 diff --git a/funsor/joint.py b/funsor/joint.py index 09fac9084..2c085ddc5 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -6,21 +6,17 @@ from functools import reduce from typing import Tuple, Union -from multipledispatch import dispatch -from multipledispatch.variadic import Variadic - import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture from funsor.delta import Delta from funsor.domains import Bint from funsor.gaussian import Gaussian, align_gaussian -from funsor.ops import AssociativeOp from funsor.tensor import Tensor, align_tensor -from funsor.terms import Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize +from funsor.terms import Cat, Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize -@dispatch(str, str, Variadic[(Gaussian, GaussianMixture)]) -def eager_cat_homogeneous(name, part_name, *parts): +@eager.register(Cat, str, Tuple[Union[Gaussian, GaussianMixture], ...], str) +def eager_cat_gaussian(name, parts, part_name): assert parts output = parts[0].output inputs = OrderedDict([(part_name, None)]) @@ -75,14 +71,9 @@ def eager_cat_homogeneous(name, part_name, *parts): # patterns for moment-matching ################################# -@moment_matching.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[object]) -def moment_matching_contract_default(*args): - return None - - -@moment_matching.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, (Number, 
Tensor), Gaussian) -def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian): - +@moment_matching.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Union[Number, Tensor], Gaussian]) +def moment_matching_contract_joint(red_op, bin_op, reduced_vars, terms): + discrete, gaussian = terms approx_vars = frozenset(v for v in reduced_vars if v.name in gaussian.inputs if v.dtype != 'real') diff --git a/funsor/optimizer.py b/funsor/optimizer.py index 9c63a8606..39f57f34a 100644 --- a/funsor/optimizer.py +++ b/funsor/optimizer.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import collections +from typing import Tuple -from multipledispatch.variadic import Variadic from opt_einsum.paths import greedy import funsor.interpreter as interpreter @@ -22,7 +22,7 @@ def unfold(cls, *args): return result -@unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) +@unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): for i, v in enumerate(terms): @@ -50,10 +50,6 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): return None -unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor])( - lambda r, b, v, *ts: unfold(Contraction, r, b, v, tuple(ts))) - - @interpreter.dispatched_interpretation def optimize(cls, *args): result = optimize.dispatch(cls, *args)(*args) @@ -66,17 +62,13 @@ def optimize(cls, *args): REAL_SIZE = 3 # the "size" of a real-valued dimension passed to the path optimizer -optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor])( - lambda r, b, v, *ts: optimize(Contraction, r, b, v, tuple(ts))) - - -@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor) -@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor) -def eager_contract_base(red_op, 
bin_op, reduced_vars, *terms): +@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, Funsor]) +@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor]) +def eager_contract_base(red_op, bin_op, reduced_vars, terms): return None -@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) +@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms): if red_op is nullop or bin_op is nullop or not (red_op, bin_op) in DISTRIBUTIVE_OPS: diff --git a/funsor/registry.py b/funsor/registry.py index e560b9ddf..8a008f787 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -9,7 +9,7 @@ from multipledispatch import Dispatcher from multipledispatch.conflict import supercedes -from funsor.util import _type_to_typing, get_origin, typing_wrap +from funsor.util import _type_to_typing, deep_type, get_origin, typing_wrap class PartialDispatcher(Dispatcher): @@ -24,7 +24,7 @@ def partial_call(self, *args): """ Likde :meth:`__call__` but avoids calling ``func()``. 
""" - types = tuple(map(type, args)) + types = tuple(map(deep_type, args)) types = tuple(map(_type_to_typing, types)) try: func = self._cache[types] diff --git a/funsor/tensor.py b/funsor/tensor.py index e4ec652fa..1152ad15f 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -11,8 +11,6 @@ import numpy as np import opt_einsum -from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -21,11 +19,13 @@ from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp from funsor.terms import ( Binary, + Cat, Funsor, FunsorMeta, Lambda, Number, Slice, + Stack, Tuple, Unary, Variable, @@ -712,8 +712,8 @@ def eager_lambda(var, expr): return Tensor(data, inputs, expr.dtype) -@dispatch(str, Variadic[Tensor]) -def eager_stack_homogeneous(name, *parts): +@eager.register(Stack, str, typing.Tuple[Tensor, ...]) +def eager_stack_tensors(name, parts): assert parts output = parts[0].output part_inputs = OrderedDict() @@ -730,9 +730,12 @@ def eager_stack_homogeneous(name, *parts): return Tensor(data, inputs, dtype=output.dtype) -@dispatch(str, str, Variadic[Tensor]) -def eager_cat_homogeneous(name, part_name, *parts): +@eager.register(Cat, str, typing.Tuple[Tensor, ...], str) +def eager_cat_tensors(name, parts, part_name): assert parts + if len(parts) == 1: + return parts[0](**{part_name: name}) + output = parts[0].output inputs = OrderedDict([(part_name, None)]) for part in parts: diff --git a/funsor/terms.py b/funsor/terms.py index 33588ea77..c95cfa660 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -12,9 +12,6 @@ from functools import reduce, singledispatch from weakref import WeakValueDictionary -from multipledispatch import dispatch -from multipledispatch.variadic import Variadic - import funsor.interpreter as interpreter import funsor.ops as ops from funsor.domains import Array, Bint, Domain, Product, Real, find_domain @@ -1282,14 +1279,9 @@ def eager_reduce(self, op, reduced_vars): return 
Stack(self.name, parts) -@eager.register(Stack, str, tuple) +@eager.register(Stack, str, typing.Tuple[Funsor, ...]) def eager_stack(name, parts): - return eager_stack_homogeneous(name, *parts) - - -@dispatch(str, Variadic[Funsor]) -def eager_stack_homogeneous(name, *parts): - return None # defer to default implementation + return None class CatMeta(FunsorMeta): @@ -1381,16 +1373,9 @@ def eager_subs(self, subs): .format(type(value))) -@eager.register(Cat, str, tuple, str) +@eager.register(Cat, str, typing.Tuple[Funsor], str) def eager_cat(name, parts, part_name): - if len(parts) == 1: - return parts[0](**{part_name: name}) - return eager_cat_homogeneous(name, part_name, *parts) - - -@dispatch(str, str, Variadic[Funsor]) -def eager_cat_homogeneous(name, part_name, *parts): - return None # defer to default implementation + return parts[0](**{part_name: name}) class Lambda(Funsor): @@ -1423,7 +1408,8 @@ def _alpha_convert(self, alpha_subs): return super()._alpha_convert(alpha_subs) -@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align)) +@eager.register(Binary, GetitemOp, Lambda, Align) +@eager.register(Binary, GetitemOp, Lambda, Funsor) def eager_getitem_lambda(op, lhs, rhs): if op.offset == 0: return Subs(lhs.expr, ((lhs.var.name, rhs),)) diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index 6fe1b3186..84d6f8e5b 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -345,15 +345,15 @@ def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None): eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable) # noqa: F821 eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal) # noqa: F821 eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn) # noqa: F821 -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, BernoulliProbs)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, BernoulliProbs])( # 
noqa: F821 eager_beta_bernoulli) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Categorical)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Categorical])( # noqa: F821 eager_dirichlet_categorical) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Multinomial)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Multinomial])( # noqa: F821 eager_dirichlet_multinomial) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Gamma)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Gamma])( # noqa: F821 eager_gamma_gamma) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Poisson)( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Poisson])( # noqa: F821 eager_gamma_poisson) eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)( # noqa: F821 eager_dirichlet_posterior) diff --git a/funsor/util.py b/funsor/util.py index ffc03c667..09b08eaff 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -258,6 +258,10 @@ def get_args(tp): return typing_extensions.get_args(tp) +def deep_type(obj): + return pytypes.deep_type(obj) + + class GenericTypeMeta(type): """ Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. 
@@ -316,6 +320,8 @@ def __getitem__(cls, tp): return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else super().__getitem__(tp) def __subclasscheck__(cls, subcls): + if isinstance(subcls, _PytypesSubclasser): + subcls = subcls.__args__[0] return pytypes.is_subtype(subcls, cls.__args__[0]) From be0a4cba43d148e1ee911bd500b6e62cc01cba80 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 18:21:45 -0500 Subject: [PATCH 08/66] add funsor.typing module --- funsor/typing.py | 134 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 funsor/typing.py diff --git a/funsor/typing.py b/funsor/typing.py new file mode 100644 index 000000000..93b33ca1b --- /dev/null +++ b/funsor/typing.py @@ -0,0 +1,134 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import functools +import typing +import typing_extensions +import weakref + +import pytypes + +from multipledispatch.conflict import supercedes +from multipledispatch.dispatcher import Dispatcher +from multipledispatch.variadic import isvariadic + + +def _type_to_typing(tp): + if tp is object: + tp = typing.Any + if isinstance(tp, tuple): + tp = typing.Union[tuple(map(_type_to_typing, tp))] + return tp + + +def get_origin(tp): + if isinstance(tp, GenericTypeMeta): + return getattr(tp, "__origin__", tp) + return typing_extensions.get_origin(tp) + + +def get_args(tp): + if isinstance(tp, GenericTypeMeta): + return getattr(tp, "__args__", tp) + return typing_extensions.get_args(tp) + + +def get_type_hints(obj, globalns=None, localns=None, include_extras=False): + if isinstance(obj, GenericTypeMeta) and hasattr(obj, "__annotations__"): + return obj.__annotations__ + return typing_extensions.get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) + + +def deep_type(obj): + return pytypes.deep_type(obj) + + +def deep_issubclass(subcls, cls): + return pytypes.is_subtype(subcls, cls) + + +def 
deep_isinstance(obj, cls): + return pytypes.is_of_type(obj, cls) + + +def deep_supercedes(xs, ys): + return supercedes(tuple(typing_wrap[_type_to_typing(x)] for x in xs), + tuple(typing_wrap[_type_to_typing(y)] for y in ys)) + + +class GenericTypeMeta(type): + """ + Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. + """ + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + if not hasattr(cls, "__args__"): + cls.__args__ = () + if cls.__args__: + base, = bases + cls.__origin__ = base + else: + cls._type_cache = weakref.WeakValueDictionary() + + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" + if arg_types not in cls._type_cache: + assert not get_args(cls), "cannot subscript a subscripted type {}".format(cls) + new_dct = cls.__dict__.copy() + new_dct.update({"__args__": arg_types}) + # type(cls) to handle GenericTypeMeta subclasses + cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) + return cls._type_cache[arg_types] + + @functools.lru_cache(maxsize=None) + def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) + if cls is subcls: + return True + + if not isinstance(subcls, GenericTypeMeta): + return super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(subcls) + + if not super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(get_origin(subcls)): + return False + + if len(get_args(cls)) != len(get_args(subcls)): + return len(get_args(cls)) == 0 + + return all(deep_issubclass(_type_to_typing(ps), _type_to_typing(pc)) + for ps, pc in zip(get_args(subcls), get_args(cls))) + + def __repr__(cls): + return get_origin(cls).__name__ + ( + "" if not get_args(cls) else + "[{}]".format(", ".join(repr(t) for t in get_args(cls)))) + + @lazy_property + def classname(cls): + return repr(cls) + + +class 
_PytypesSubclasser(GenericTypeMeta): + def __getitem__(cls, tp): + return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else super().__getitem__(tp) + + def __subclasscheck__(cls, subcls): + if isinstance(subcls, _PytypesSubclasser): + subcls = subcls.__args__[0] + return deep_issubclass(subcls, cls.__args__[0]) + + +class typing_wrap(metaclass=_PytypesSubclasser): + pass + + +class TypingDispatcher(Dispatcher): + + def register(self, *types): + types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) + if self.default: + objects = (typing_wrap[typing.Any],) * len(types) + if objects != types and deep_supercedes(types, objects): + super().register(*objects)(self.default) + return super().register(*types) From b2931f4c1e549722aa2a508697641bb9b66bf636 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 18:21:53 -0500 Subject: [PATCH 09/66] use funsor.typing --- funsor/domains.py | 3 +- funsor/registry.py | 23 +---------- funsor/terms.py | 3 +- funsor/typing.py | 4 +- funsor/util.py | 96 ---------------------------------------------- 5 files changed, 8 insertions(+), 121 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 22568700f..c950de948 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -8,7 +8,8 @@ from functools import reduce import funsor.ops as ops -from funsor.util import GenericTypeMeta, broadcast_shape, get_backend, get_tracing_state, quote +from funsor.typing import GenericTypeMeta +from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote class Domain(GenericTypeMeta): diff --git a/funsor/registry.py b/funsor/registry.py index 8a008f787..e4f362418 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -1,18 +1,12 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -import functools -import typing from collections import defaultdict -import pytypes -from multipledispatch import Dispatcher -from multipledispatch.conflict import supercedes +from funsor.typing import TypingDispatcher, _type_to_typing, deep_type, get_origin -from funsor.util import _type_to_typing, deep_type, get_origin, typing_wrap - -class PartialDispatcher(Dispatcher): +class PartialDispatcher(TypingDispatcher): """ Wrapper to avoid appearance in stack traces. """ @@ -37,14 +31,6 @@ def partial_call(self, *args): self._cache[types] = func return func - def register(self, *types): - types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) - if self.default: - objects = (typing_wrap[typing.Any],) * len(types) - if objects != types and safe_supercedes(types, objects): - super().register(*objects)(self.default) - return super().register(*types) - class PartialDefault: def __init__(self, default): @@ -58,11 +44,6 @@ def partial_call(self, *args): return self.default -def safe_supercedes(xs, ys): - return supercedes(tuple(typing_wrap[_type_to_typing(x)] for x in xs), - tuple(typing_wrap[_type_to_typing(y)] for y in ys)) - - class KeyedRegistry(object): def __init__(self, default=None): diff --git a/funsor/terms.py b/funsor/terms.py index c95cfa660..334b444c2 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -17,7 +17,8 @@ from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.util import GenericTypeMeta, getargspec, get_backend, lazy_property, pretty, quote, _type_to_typing +from funsor.typing import GenericTypeMeta, _type_to_typing +from funsor.util import getargspec, get_backend, lazy_property, pretty, quote def substitute(expr, subs): diff --git a/funsor/typing.py b/funsor/typing.py index 93b33ca1b..215a4d989 100644 --- a/funsor/typing.py 
+++ b/funsor/typing.py @@ -104,7 +104,7 @@ def __repr__(cls): "" if not get_args(cls) else "[{}]".format(", ".join(repr(t) for t in get_args(cls)))) - @lazy_property + @property def classname(cls): return repr(cls) @@ -124,7 +124,7 @@ class typing_wrap(metaclass=_PytypesSubclasser): class TypingDispatcher(Dispatcher): - + def register(self, *types): types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) if self.default: diff --git a/funsor/util.py b/funsor/util.py index 09b08eaff..44234808b 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -4,13 +4,8 @@ import functools import inspect import re -import typing -import typing_extensions -import weakref -import pytypes import numpy as np -from multipledispatch.variadic import isvariadic _FUNSOR_BACKEND = "numpy" @@ -236,94 +231,3 @@ def decorator(fn): setattr(cls, name_, fn) return fn return decorator - - -def _type_to_typing(tp): - if tp is object: - tp = typing.Any - if isinstance(tp, tuple): - tp = typing.Union[tuple(map(_type_to_typing, tp))] - return tp - - -def get_origin(tp): - if isinstance(tp, GenericTypeMeta): - return getattr(tp, "__origin__", tp) - return typing_extensions.get_origin(tp) - - -def get_args(tp): - if isinstance(tp, GenericTypeMeta): - return getattr(tp, "__args__", tp) - return typing_extensions.get_args(tp) - - -def deep_type(obj): - return pytypes.deep_type(obj) - - -class GenericTypeMeta(type): - """ - Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. 
- """ - def __init__(cls, name, bases, dct): - super().__init__(name, bases, dct) - if not hasattr(cls, "__args__"): - cls.__args__ = () - if cls.__args__: - base, = bases - cls.__origin__ = base - else: - cls._type_cache = weakref.WeakValueDictionary() - - def __getitem__(cls, arg_types): - if not isinstance(arg_types, tuple): - arg_types = (arg_types,) - assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" - if arg_types not in cls._type_cache: - assert not get_args(cls), "cannot subscript a subscripted type {}".format(cls) - new_dct = cls.__dict__.copy() - new_dct.update({"__args__": arg_types}) - # type(cls) to handle GenericTypeMeta subclasses - cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) - return cls._type_cache[arg_types] - - @functools.lru_cache(maxsize=None) - def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) - if cls is subcls: - return True - - if not isinstance(subcls, GenericTypeMeta): - return super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(subcls) - - if not super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(get_origin(subcls)): - return False - - if len(get_args(cls)) != len(get_args(subcls)): - return len(get_args(cls)) == 0 - - return all(pytypes.is_subtype(_type_to_typing(ps), _type_to_typing(pc)) - for ps, pc in zip(get_args(subcls), get_args(cls))) - - def __repr__(cls): - return get_origin(cls).__name__ + ( - "" if not get_args(cls) else - "[{}]".format(", ".join(repr(t) for t in get_args(cls)))) - - @lazy_property - def classname(cls): - return repr(cls) - - -class _PytypesSubclasser(GenericTypeMeta): - def __getitem__(cls, tp): - return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else super().__getitem__(tp) - - def __subclasscheck__(cls, subcls): - if isinstance(subcls, _PytypesSubclasser): - subcls = subcls.__args__[0] - return pytypes.is_subtype(subcls, cls.__args__[0]) - - -class typing_wrap(metaclass=_PytypesSubclasser): - 
pass From 24985dbdc45408f428b356a30e3e6ce4d9969a06 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 18:22:57 -0500 Subject: [PATCH 10/66] uncomment pattern --- funsor/gaussian.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/gaussian.py b/funsor/gaussian.py index f7a53ed08..8c74cf0aa 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -645,8 +645,8 @@ def eager_add_gaussian_gaussian(op, lhs, rhs): return Gaussian(info_vec, precision, inputs) -# @eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian)) -# @eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian) +@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian)) +@eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian) def eager_sub(op, lhs, rhs): return lhs + -rhs From 6538ba879748478a27ce1a1f54906181a2d7316a Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 18:36:09 -0500 Subject: [PATCH 11/66] move more things to funsor.typing --- funsor/registry.py | 19 +---------------- funsor/typing.py | 53 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index e4f362418..b430215dc 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -3,7 +3,7 @@ from collections import defaultdict -from funsor.typing import TypingDispatcher, _type_to_typing, deep_type, get_origin +from funsor.typing import TypingDispatcher, get_origin class PartialDispatcher(TypingDispatcher): @@ -14,23 +14,6 @@ def __init__(self, name, default=None): self.default = default if default is None else PartialDefault(default) super().__init__(name) - def partial_call(self, *args): - """ - Likde :meth:`__call__` but avoids calling ``func()``. 
- """ - types = tuple(map(deep_type, args)) - types = tuple(map(_type_to_typing, types)) - try: - func = self._cache[types] - except KeyError: - func = self.dispatch(*types) - if func is None: - raise NotImplementedError( - 'Could not find signature for %s: <%s>' % - (self.name, ', '.join(cls.__name__ for cls in types))) - self._cache[types] = func - return func - class PartialDefault: def __init__(self, default): diff --git a/funsor/typing.py b/funsor/typing.py index 215a4d989..b2f69e63f 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -10,7 +10,7 @@ from multipledispatch.conflict import supercedes from multipledispatch.dispatcher import Dispatcher -from multipledispatch.variadic import isvariadic +from multipledispatch.variadic import VariadicSignatureType, isvariadic def _type_to_typing(tp): @@ -51,11 +51,6 @@ def deep_isinstance(obj, cls): return pytypes.is_of_type(obj, cls) -def deep_supercedes(xs, ys): - return supercedes(tuple(typing_wrap[_type_to_typing(x)] for x in xs), - tuple(typing_wrap[_type_to_typing(y)] for y in ys)) - - class GenericTypeMeta(type): """ Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. @@ -120,15 +115,57 @@ def __subclasscheck__(cls, subcls): class typing_wrap(metaclass=_PytypesSubclasser): + """ + Metaclass for overriding the runtime behavior of `typing` objects. + """ pass -class TypingDispatcher(Dispatcher): +def deep_supercedes(xs, ys): + """typing-compatible version of multipledispatch.conflict.supercedes""" + return supercedes(tuple(typing_wrap[_type_to_typing(x)] for x in xs), + tuple(typing_wrap[_type_to_typing(y)] for y in ys)) + +class DeepVariadicSignatureType(VariadicSignatureType): + pass # TODO + + +class Variadic(metaclass=DeepVariadicSignatureType): + """ + A version of multipledispatch.variadic.Variadic compatible with typing. 
+ """ + pass # TODO + + +class TypingDispatcher(Dispatcher): + """ + A Dispatcher class designed for compatibility with the typing standard library. + """ def register(self, *types): types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) - if self.default: + if getattr(self, "default", None): objects = (typing_wrap[typing.Any],) * len(types) if objects != types and deep_supercedes(types, objects): super().register(*objects)(self.default) return super().register(*types) + + def partial_call(self, *args): + """ + Likde :meth:`__call__` but avoids calling ``func()``. + """ + types = tuple(map(deep_type, args)) + types = tuple(map(_type_to_typing, types)) + try: + func = self._cache[types] + except KeyError: + func = self.dispatch(*types) + if func is None: + raise NotImplementedError( + 'Could not find signature for %s: <%s>' % + (self.name, ', '.join(cls.__name__ for cls in types))) + self._cache[types] = func + return func + + def __call__(self, *args): + return self.partial_call(*args)(*args) From 07fc6b73d78a01eabe7c69bdcda3cb9f3391ff16 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 18:45:27 -0500 Subject: [PATCH 12/66] organize code in funsor.typing --- funsor/typing.py | 65 +++++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index b2f69e63f..62640462d 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -13,6 +13,29 @@ from multipledispatch.variadic import VariadicSignatureType, isvariadic +################################# +# Runtime type-checking helpers +################################# + +def deep_isinstance(obj, cls): + """replaces isinstance()""" + return pytypes.is_of_type(obj, cls) + + +def deep_issubclass(subcls, cls): + """replaces issubclass()""" + return pytypes.is_subtype(subcls, cls) + + +def deep_type(obj): + """replaces type()""" + return pytypes.deep_type(obj) + + +############################################## +# 
Funsor-compatible typing introspection API +############################################## + def _type_to_typing(tp): if tp is object: tp = typing.Any @@ -21,35 +44,27 @@ def _type_to_typing(tp): return tp -def get_origin(tp): - if isinstance(tp, GenericTypeMeta): - return getattr(tp, "__origin__", tp) - return typing_extensions.get_origin(tp) - - def get_args(tp): if isinstance(tp, GenericTypeMeta): return getattr(tp, "__args__", tp) return typing_extensions.get_args(tp) +def get_origin(tp): + if isinstance(tp, GenericTypeMeta): + return getattr(tp, "__origin__", tp) + return typing_extensions.get_origin(tp) + + def get_type_hints(obj, globalns=None, localns=None, include_extras=False): if isinstance(obj, GenericTypeMeta) and hasattr(obj, "__annotations__"): return obj.__annotations__ return typing_extensions.get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) -def deep_type(obj): - return pytypes.deep_type(obj) - - -def deep_issubclass(subcls, cls): - return pytypes.is_subtype(subcls, cls) - - -def deep_isinstance(obj, cls): - return pytypes.is_of_type(obj, cls) - +###################################################################### +# Metaclass for generating parametric types with Tuple-like variance +###################################################################### class GenericTypeMeta(type): """ @@ -104,17 +119,21 @@ def classname(cls): return repr(cls) -class _PytypesSubclasser(GenericTypeMeta): +############################################################## +# Tools and overrides for typing-compatible multipledispatch +############################################################## + +class _RuntimeSubclassCheckMeta(GenericTypeMeta): def __getitem__(cls, tp): return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else super().__getitem__(tp) def __subclasscheck__(cls, subcls): - if isinstance(subcls, _PytypesSubclasser): + if isinstance(subcls, _RuntimeSubclassCheckMeta): subcls = subcls.__args__[0] return 
deep_issubclass(subcls, cls.__args__[0]) -class typing_wrap(metaclass=_PytypesSubclasser): +class typing_wrap(metaclass=_RuntimeSubclassCheckMeta): """ Metaclass for overriding the runtime behavior of `typing` objects. """ @@ -128,14 +147,14 @@ def deep_supercedes(xs, ys): class DeepVariadicSignatureType(VariadicSignatureType): - pass # TODO + pass # TODO define __getitem__, possibly __eq__/__hash__? class Variadic(metaclass=DeepVariadicSignatureType): """ A version of multipledispatch.variadic.Variadic compatible with typing. """ - pass # TODO + pass # TODO is there anything else to do here? class TypingDispatcher(Dispatcher): @@ -144,7 +163,7 @@ class TypingDispatcher(Dispatcher): """ def register(self, *types): types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) - if getattr(self, "default", None): + if getattr(self, "default", None): # XXX should this class have default? objects = (typing_wrap[typing.Any],) * len(types) if objects != types and deep_supercedes(types, objects): super().register(*objects)(self.default) From b598204aec76a9223021d4f764d4074e9aa0a95d Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 18:56:49 -0500 Subject: [PATCH 13/66] nit --- funsor/typing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 62640462d..2afeb0984 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -57,8 +57,8 @@ def get_origin(tp): def get_type_hints(obj, globalns=None, localns=None, include_extras=False): - if isinstance(obj, GenericTypeMeta) and hasattr(obj, "__annotations__"): - return obj.__annotations__ + if isinstance(obj, GenericTypeMeta): + return getattr(obj, "__annotations__", {}) return typing_extensions.get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) @@ -152,7 +152,7 @@ class DeepVariadicSignatureType(VariadicSignatureType): class Variadic(metaclass=DeepVariadicSignatureType): """ - A version of 
multipledispatch.variadic.Variadic compatible with typing. + A typing-compatible drop-in replacement for multipledispatch.variadic.Variadic. """ pass # TODO is there anything else to do here? From 7630c1dbbad47a3d547e29074cfcad37d11b3b66 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 19:02:44 -0500 Subject: [PATCH 14/66] change syntax of typing_wrap to callable --- funsor/typing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 2afeb0984..bc0f2ff07 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -124,8 +124,8 @@ def classname(cls): ############################################################## class _RuntimeSubclassCheckMeta(GenericTypeMeta): - def __getitem__(cls, tp): - return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else super().__getitem__(tp) + def __call__(cls, tp): + return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else cls[tp] def __subclasscheck__(cls, subcls): if isinstance(subcls, _RuntimeSubclassCheckMeta): @@ -135,15 +135,15 @@ def __subclasscheck__(cls, subcls): class typing_wrap(metaclass=_RuntimeSubclassCheckMeta): """ - Metaclass for overriding the runtime behavior of `typing` objects. + Utility callable for overriding the runtime behavior of `typing` objects. """ pass def deep_supercedes(xs, ys): """typing-compatible version of multipledispatch.conflict.supercedes""" - return supercedes(tuple(typing_wrap[_type_to_typing(x)] for x in xs), - tuple(typing_wrap[_type_to_typing(y)] for y in ys)) + return supercedes(tuple(typing_wrap(_type_to_typing(x)) for x in xs), + tuple(typing_wrap(_type_to_typing(y)) for y in ys)) class DeepVariadicSignatureType(VariadicSignatureType): @@ -162,9 +162,9 @@ class TypingDispatcher(Dispatcher): A Dispatcher class designed for compatibility with the typing standard library. 
""" def register(self, *types): - types = tuple(typing_wrap[tp] for tp in map(_type_to_typing, types)) + types = tuple(map(typing_wrap, map(_type_to_typing, types)) if getattr(self, "default", None): # XXX should this class have default? - objects = (typing_wrap[typing.Any],) * len(types) + objects = (typing_wrap(typing.Any),) * len(types) if objects != types and deep_supercedes(types, objects): super().register(*objects)(self.default) return super().register(*types) From 09c2c8b314f836d397ca6cf3ae436c54057e7b07 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Jan 2021 23:27:54 -0500 Subject: [PATCH 15/66] remove memoize --- funsor/typing.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index bc0f2ff07..5b9ec7d6f 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -17,6 +17,14 @@ # Runtime type-checking helpers ################################# +def _type_to_typing(tp): + if tp is object: + tp = typing.Any + if isinstance(tp, tuple): + tp = typing.Union[tuple(map(_type_to_typing, tp))] + return tp + + def deep_isinstance(obj, cls): """replaces isinstance()""" return pytypes.is_of_type(obj, cls) @@ -36,14 +44,6 @@ def deep_type(obj): # Funsor-compatible typing introspection API ############################################## -def _type_to_typing(tp): - if tp is object: - tp = typing.Any - if isinstance(tp, tuple): - tp = typing.Union[tuple(map(_type_to_typing, tp))] - return tp - - def get_args(tp): if isinstance(tp, GenericTypeMeta): return getattr(tp, "__args__", tp) @@ -92,7 +92,6 @@ def __getitem__(cls, arg_types): cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) return cls._type_cache[arg_types] - @functools.lru_cache(maxsize=None) def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) if cls is subcls: return True @@ -146,11 +145,11 @@ def deep_supercedes(xs, ys): tuple(typing_wrap(_type_to_typing(y)) for y in ys)) -class 
DeepVariadicSignatureType(VariadicSignatureType): +class _DeepVariadicSignatureType(VariadicSignatureType): pass # TODO define __getitem__, possibly __eq__/__hash__? -class Variadic(metaclass=DeepVariadicSignatureType): +class Variadic(metaclass=_DeepVariadicSignatureType): """ A typing-compatible drop-in replacement for multipledispatch.variadic.Variadic. """ From 4c8eb9e674f7c32554cdb928adae69ece23257b7 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 27 Jan 2021 17:07:40 -0500 Subject: [PATCH 16/66] attempt to use custom issubclass --- funsor/typing.py | 55 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 5b9ec7d6f..00b737040 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -27,12 +27,14 @@ def _type_to_typing(tp): def deep_isinstance(obj, cls): """replaces isinstance()""" - return pytypes.is_of_type(obj, cls) + # return pytypes.is_of_type(obj, cls) + return deep_issubclass(deep_type(obj), cls) def deep_issubclass(subcls, cls): """replaces issubclass()""" - return pytypes.is_subtype(subcls, cls) + # return pytypes.is_subtype(subcls, cls) + return _issubclass_tuple(subcls, cls) def deep_type(obj): @@ -40,20 +42,61 @@ def deep_type(obj): return pytypes.deep_type(obj) +def _issubclass_tuple(subcls, cls): + + if get_origin(cls) is typing.Union: + return any(_issubclass_tuple(subcls, arg) for arg in get_args(cls)) + + if get_origin(subcls) is typing.Union: # XXX is this right? 
+ return any(_issubclass_tuple(arg, cls) for arg in get_args(subcls)) + + if cls is typing.Any: + return True + + if subcls is typing.Any: + return False + + if issubclass(get_origin(subcls), typing.Tuple) and \ + issubclass(get_origin(cls), typing.Tuple): + + if not issubclass(get_origin(subcls), get_origin(cls)): + return False + + if not get_args(cls): # cls is base Tuple + return True + + if get_args(cls)[-1] is Ellipsis: # cls variadic + if get_args(subcls)[-1] is Ellipsis: # both variadic + return _issubclass_tuple(get_args(subcls)[0], get_args(cls)[0]) + return all(_issubclass_tuple(a, get_args(cls)[0]) for a in get_args(subcls)) + + if get_args(subcls)[-1] is Ellipsis: # only subcls variadic + # issubclass(Tuple[A, ...], Tuple[X, Y]) == False + return False + + # neither variadic + return len(get_args(cls)) == len(get_args(subcls)) and \ + all(_issubclass_tuple(a, b) for a, b in zip(get_args(subcls), get_args(cls))) + + return issubclass(subcls, cls) + + ############################################## # Funsor-compatible typing introspection API ############################################## def get_args(tp): if isinstance(tp, GenericTypeMeta): - return getattr(tp, "__args__", tp) - return typing_extensions.get_args(tp) + return getattr(tp, "__args__", ()) + result = typing_extensions.get_args(tp) + return () if result is None else result def get_origin(tp): if isinstance(tp, GenericTypeMeta): return getattr(tp, "__origin__", tp) - return typing_extensions.get_origin(tp) + result = typing_extensions.get_origin(tp) + return tp if result is None else result def get_type_hints(obj, globalns=None, localns=None, include_extras=False): @@ -161,7 +204,7 @@ class TypingDispatcher(Dispatcher): A Dispatcher class designed for compatibility with the typing standard library. 
""" def register(self, *types): - types = tuple(map(typing_wrap, map(_type_to_typing, types)) + types = tuple(map(typing_wrap, map(_type_to_typing, types))) if getattr(self, "default", None): # XXX should this class have default? objects = (typing_wrap(typing.Any),) * len(types) if objects != types and deep_supercedes(types, objects): From 0e842c3ffb5d623d12b81f355683fe47291272a9 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 17:48:26 -0500 Subject: [PATCH 17/66] Add FrozenSet to deep_issubclass and start deep_type --- funsor/terms.py | 3 +- funsor/typing.py | 71 ++++++++++++++++++++++++++++++++-------------- test/test_terms.py | 4 +-- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 334b444c2..c867228d8 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -17,7 +17,7 @@ from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.typing import GenericTypeMeta, _type_to_typing +from funsor.typing import GenericTypeMeta from funsor.util import getargspec, get_backend, lazy_property, pretty, quote @@ -213,7 +213,6 @@ def __getitem__(cls, arg_types): arg_types = (arg_types,) assert len(arg_types) == len(cls._ast_fields), \ "Must provide exactly one type per subexpression" - arg_types = tuple(map(_type_to_typing, arg_types)) return super().__getitem__(arg_types) def __call__(cls, *args, **kwargs): diff --git a/funsor/typing.py b/funsor/typing.py index 00b737040..628be4880 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -6,8 +6,6 @@ import typing_extensions import weakref -import pytypes - from multipledispatch.conflict import supercedes from multipledispatch.dispatcher import Dispatcher from multipledispatch.variadic import VariadicSignatureType, isvariadic @@ -17,14 +15,6 @@ # Runtime type-checking helpers 
################################# -def _type_to_typing(tp): - if tp is object: - tp = typing.Any - if isinstance(tp, tuple): - tp = typing.Union[tuple(map(_type_to_typing, tp))] - return tp - - def deep_isinstance(obj, cls): """replaces isinstance()""" # return pytypes.is_of_type(obj, cls) @@ -34,21 +24,34 @@ def deep_isinstance(obj, cls): def deep_issubclass(subcls, cls): """replaces issubclass()""" # return pytypes.is_subtype(subcls, cls) - return _issubclass_tuple(subcls, cls) + return _deep_issubclass(subcls, cls) def deep_type(obj): """replaces type()""" - return pytypes.deep_type(obj) + # return pytypes.deep_type(obj) + return _deep_type(obj) + + +def _deep_type(obj): + + if isinstance(obj, tuple): + return typing.Tuple[tuple(map(deep_type, obj))] if obj else typing.Tuple + if isinstance(obj, frozenset): + return typing.FrozenSet[next(map(deep_type, obj))] if obj else typing.FrozenSet -def _issubclass_tuple(subcls, cls): + return type(obj) + + +@functools.lru_cache(maxsize=None) +def _deep_issubclass(subcls, cls): if get_origin(cls) is typing.Union: - return any(_issubclass_tuple(subcls, arg) for arg in get_args(cls)) + return any(_deep_issubclass(subcls, arg) for arg in get_args(cls)) - if get_origin(subcls) is typing.Union: # XXX is this right? 
- return any(_issubclass_tuple(arg, cls) for arg in get_args(subcls)) + if get_origin(subcls) is typing.Union: + return all(_deep_issubclass(arg, cls) for arg in get_args(subcls)) if cls is typing.Any: return True @@ -56,8 +59,21 @@ def _issubclass_tuple(subcls, cls): if subcls is typing.Any: return False - if issubclass(get_origin(subcls), typing.Tuple) and \ - issubclass(get_origin(cls), typing.Tuple): + if issubclass(get_origin(cls), typing.FrozenSet): + + if not issubclass(get_origin(subcls), get_origin(cls)): + return False + + if not get_args(cls): + return True + + if not get_args(subcls): + return get_args(cls)[0] is typing.Any + + return len(get_args(subcls)) == len(get_args(cls)) == 1 and \ + _deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) + + if issubclass(get_origin(cls), typing.Tuple): if not issubclass(get_origin(subcls), get_origin(cls)): return False @@ -65,10 +81,13 @@ def _issubclass_tuple(subcls, cls): if not get_args(cls): # cls is base Tuple return True + if not get_args(subcls): + return get_args(cls)[0] is typing.Any + if get_args(cls)[-1] is Ellipsis: # cls variadic if get_args(subcls)[-1] is Ellipsis: # both variadic - return _issubclass_tuple(get_args(subcls)[0], get_args(cls)[0]) - return all(_issubclass_tuple(a, get_args(cls)[0]) for a in get_args(subcls)) + return _deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) + return all(_deep_issubclass(a, get_args(cls)[0]) for a in get_args(subcls)) if get_args(subcls)[-1] is Ellipsis: # only subcls variadic # issubclass(Tuple[A, ...], Tuple[X, Y]) == False @@ -76,11 +95,20 @@ def _issubclass_tuple(subcls, cls): # neither variadic return len(get_args(cls)) == len(get_args(subcls)) and \ - all(_issubclass_tuple(a, b) for a, b in zip(get_args(subcls), get_args(cls))) + all(_deep_issubclass(a, b) for a, b in zip(get_args(subcls), get_args(cls))) return issubclass(subcls, cls) +@functools.lru_cache(maxsize=None) +def _type_to_typing(tp): + if tp is object: + tp = typing.Any + if 
isinstance(tp, tuple): + tp = typing.Union[tuple(map(_type_to_typing, tp))] + return tp + + ############################################## # Funsor-compatible typing introspection API ############################################## @@ -127,6 +155,7 @@ def __getitem__(cls, arg_types): if not isinstance(arg_types, tuple): arg_types = (arg_types,) assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" + arg_types = tuple(map(_type_to_typing, arg_types)) if arg_types not in cls._type_cache: assert not get_args(cls), "cannot subscript a subscripted type {}".format(cls) new_dct = cls.__dict__.copy() diff --git a/test/test_terms.py b/test/test_terms.py index f36a71e92..bd8779d50 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -536,8 +536,6 @@ def test_align_simple(): # Unions ("Reduce[ops.AssociativeOp, (Number, Stack[str, (tuple, typing.Tuple[Number, Number])]), frozenset]", "Funsor"), ("Reduce[ops.AssociativeOp, (Number, Stack), frozenset]", "Reduce[ops.Op, Funsor, frozenset]"), - ("Reduce[ops.AssociativeOp, (Stack, Reduce[ops.AssociativeOp, (Number, Stack), frozenset]), frozenset]", - "Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]"), ]) def test_parametric_subclass(subcls_expr, cls_expr): subcls = eval(subcls_expr) @@ -565,6 +563,8 @@ def test_parametric_subclass(subcls_expr, cls_expr): ("Reduce[ops.Op, Funsor, frozenset]", "Reduce[ops.AssociativeOp, (Number, Stack), frozenset]"), ("Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]", "Reduce[ops.AssociativeOp, (Stack[str, tuple], Reduce[ops.AssociativeOp, (Cat, Stack), frozenset]), frozenset]"), + ("Reduce[ops.AssociativeOp, (Stack, Reduce[ops.AssociativeOp, (Number, Stack), frozenset]), frozenset]", + "Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]"), ]) def test_not_parametric_subclass(subcls_expr, cls_expr): subcls = eval(subcls_expr) From 525de14a9d8219799232641bce38209c9b711abf Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 
17:55:50 -0500 Subject: [PATCH 18/66] use deep_type in reflect --- funsor/terms.py | 7 ++----- funsor/typing.py | 2 -- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index c867228d8..8f32ba0eb 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -17,7 +17,7 @@ from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.typing import GenericTypeMeta +from funsor.typing import GenericTypeMeta, deep_type from funsor.util import getargspec, get_backend, lazy_property, pretty, quote @@ -78,10 +78,7 @@ def reflect(cls, *args, **kwargs): if cache_key in cls._cons_cache: return cls._cons_cache[cache_key] - arg_types = tuple(typing.Tuple[tuple(map(type, arg))] - if (type(arg) is tuple and all(isinstance(a, Funsor) for a in arg)) - else typing.Tuple if (type(arg) is tuple and not arg) - else type(arg) for arg in args) + arg_types = tuple(map(deep_type, args)) cls_specific = (cls.__origin__ if cls.__args__ else cls)[arg_types] result = super(FunsorMeta, cls_specific).__call__(*args) result._ast_values = args diff --git a/funsor/typing.py b/funsor/typing.py index 628be4880..37f286ebf 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -104,8 +104,6 @@ def _deep_issubclass(subcls, cls): def _type_to_typing(tp): if tp is object: tp = typing.Any - if isinstance(tp, tuple): - tp = typing.Union[tuple(map(_type_to_typing, tp))] return tp From c4e19ec62459db019624bc70e24bf35101d4f3ec Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 18:03:02 -0500 Subject: [PATCH 19/66] fix subclass tests --- test/test_terms.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_terms.py b/test/test_terms.py index bd8779d50..79793b70a 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -534,8 +534,8 @@ def 
test_align_simple(): ("Stack[str, typing.Tuple[Number, Number, Number]]", "Stack"), ("Stack[str, typing.Tuple[Number, Number, Number]]", "Stack[str, tuple]"), # Unions - ("Reduce[ops.AssociativeOp, (Number, Stack[str, (tuple, typing.Tuple[Number, Number])]), frozenset]", "Funsor"), - ("Reduce[ops.AssociativeOp, (Number, Stack), frozenset]", "Reduce[ops.Op, Funsor, frozenset]"), + ("Reduce[ops.AssociativeOp, typing.Union[Number, Stack[str, typing.Tuple[Number, Number]]], frozenset]", "Funsor"), + ("Reduce[ops.AssociativeOp, typing.Union[Number, Stack], frozenset]", "Reduce[ops.Op, Funsor, frozenset]"), ]) def test_parametric_subclass(subcls_expr, cls_expr): subcls = eval(subcls_expr) @@ -559,12 +559,12 @@ def test_parametric_subclass(subcls_expr, cls_expr): ("Stack[str, typing.Tuple[Number, Number]]", "Stack[str, typing.Tuple[Number, Reduce]]"), ("Stack[str, typing.Tuple[Number, Reduce]]", "Stack[str, typing.Tuple[Number, Number]]"), # Unions - ("Funsor", "Reduce[ops.AssociativeOp, (Number, Funsor), frozenset]"), - ("Reduce[ops.Op, Funsor, frozenset]", "Reduce[ops.AssociativeOp, (Number, Stack), frozenset]"), - ("Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]", - "Reduce[ops.AssociativeOp, (Stack[str, tuple], Reduce[ops.AssociativeOp, (Cat, Stack), frozenset]), frozenset]"), - ("Reduce[ops.AssociativeOp, (Stack, Reduce[ops.AssociativeOp, (Number, Stack), frozenset]), frozenset]", - "Reduce[(ops.Op, ops.AssociativeOp), Stack, frozenset]"), + ("Funsor", "Reduce[ops.AssociativeOp, typing.Union[Number, Funsor], frozenset]"), + ("Reduce[ops.Op, Funsor, frozenset]", "Reduce[ops.AssociativeOp, typing.Union[Number, Stack], frozenset]"), + ("Reduce[typing.Union[ops.Op, ops.AssociativeOp], Stack, frozenset]", + "Reduce[ops.AssociativeOp, typing.Union[Stack[str, tuple], Reduce[ops.AssociativeOp, typing.Union[Cat, Stack], frozenset]], frozenset]"), # noqa: E501 + ("Reduce[ops.AssociativeOp, typing.Union[Stack, Reduce[ops.AssociativeOp, typing.Union[Number, Stack], 
frozenset]], frozenset]", # noqa: E501 + "Reduce[typing.Union[ops.Op, ops.AssociativeOp], Stack, frozenset]"), ]) def test_not_parametric_subclass(subcls_expr, cls_expr): subcls = eval(subcls_expr) From 2cf8aa9d6f371bb3c94c086edee7aac64a036d6c Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 18:03:53 -0500 Subject: [PATCH 20/66] remove pytypes dependency --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index 430108dfa..e8148b577 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,6 @@ 'multipledispatch', 'numpy>=1.7', 'opt_einsum>=2.3.2', - # pin pytypes to master - 'git+https://github.com/Stewori/pytypes.git@master#egg=pytypes', ], extras_require={ 'torch': [ From 00fee9388f380c1bdd9a81ea2ef84450a721c826 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 18:31:13 -0500 Subject: [PATCH 21/66] revert nits --- funsor/interpreter.py | 1 - funsor/util.py | 1 - 2 files changed, 2 deletions(-) diff --git a/funsor/interpreter.py b/funsor/interpreter.py index f06810037..bc9e5b5d6 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -330,7 +330,6 @@ def dispatched_interpretation(fn): else: fn.register = registry.register fn.dispatch = registry.dispatch - fn._registry = registry return fn diff --git a/funsor/util.py b/funsor/util.py index 44234808b..d0692ebd8 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -7,7 +7,6 @@ import numpy as np - _FUNSOR_BACKEND = "numpy" _JAX_LOADED = False From c55eb8e5d439c2d29587a00003c517524170988a Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 18:46:39 -0500 Subject: [PATCH 22/66] Revert Variadic removal --- funsor/cnf.py | 80 ++++++++++++++++++----------------- funsor/delta.py | 6 +-- funsor/distribution.py | 15 +++---- funsor/integrate.py | 9 ++-- funsor/jax/distributions.py | 10 ++--- funsor/joint.py | 21 ++++++--- funsor/optimizer.py | 20 ++++++--- funsor/tensor.py | 15 +++---- funsor/terms.py | 26 +++++++++--- funsor/torch/distributions.py | 10 ++--- 10 files 
changed, 117 insertions(+), 95 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index dfddd0099..71a9cc59c 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -8,6 +8,7 @@ from typing import Tuple, Union import opt_einsum +from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -15,14 +16,13 @@ from funsor.delta import Delta from funsor.domains import find_domain from funsor.gaussian import Gaussian -from funsor.interpreter import interpretation +from funsor.interpreter import interpretation, recursion_reinterpret from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop from funsor.tensor import Tensor from funsor.terms import ( Align, Binary, Funsor, - FunsorMeta, Number, Reduce, Subs, @@ -36,15 +36,7 @@ from funsor.util import broadcast_shape, get_backend, quote -class ContractionMeta(FunsorMeta): - - def __call__(self, red_op, bin_op, reduced_vars, *terms): - if len(terms) == 1 and isinstance(terms[0], tuple): - terms, = terms - return super().__call__(red_op, bin_op, reduced_vars, tuple(terms)) - - -class Contraction(Funsor, metaclass=ContractionMeta): +class Contraction(Funsor): """ Declarative representation of a finitary sum-product operation. 
@@ -185,7 +177,17 @@ def _(arg, indent, out): out[-1] = i, line + ")" -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) +@recursion_reinterpret.register(Contraction) +def recursion_reinterpret_contraction(x): + return type(x)(*map(recursion_reinterpret, (x.red_op, x.bin_op, x.reduced_vars) + x.terms)) + + +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor]) +def eager_contraction_generic_to_tuple(red_op, bin_op, reduced_vars, *terms): + return eager(Contraction, red_op, bin_op, reduced_vars, terms) + + +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms): # Count the number of terms in which each variable is reduced. counts = Counter() @@ -226,16 +228,15 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms): return None -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor]) -def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, terms): - term, = terms +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor) +def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term): args = red_op, term, reduced_vars return eager.dispatch(Reduce, *args)(*args) -@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, Funsor]) -def eager_contraction_to_binary(red_op, bin_op, reduced_vars, terms): - lhs, rhs = terms +@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor) +def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs): + if not reduced_vars.issubset(lhs.input_vars & rhs.input_vars): args = red_op, bin_op, reduced_vars, (lhs, rhs) result = eager.dispatch(Contraction, *args)(*args) @@ -250,16 +251,16 @@ def eager_contraction_to_binary(red_op, bin_op, reduced_vars, terms): return result -@eager.register(Contraction, ops.AddOp, 
ops.MulOp, frozenset, Tuple[Tensor, Tensor]) -def eager_contraction_tensor(red_op, bin_op, reduced_vars, terms): +@eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset, Tensor, Tensor) +def eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms): if not all(term.dtype == "real" for term in terms): raise NotImplementedError('TODO') backend = BACKEND_TO_EINSUM_BACKEND[get_backend()] return _eager_contract_tensors(reduced_vars, terms, backend=backend) -@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Tensor, Tensor]) -def eager_contraction_tensor(red_op, bin_op, reduced_vars, terms): +@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tensor, Tensor) +def eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms): if not all(term.dtype == "real" for term in terms): raise NotImplementedError('TODO') backend = BACKEND_TO_LOGSUMEXP_BACKEND[get_backend()] @@ -306,9 +307,8 @@ def _eager_contract_tensors(reduced_vars, terms, backend): # Pyro's gaussian_tensordot() here. Until then we must eagerly add the # possibly-rank-deficient terms before reducing to avoid Cholesky errors. 
@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, - Tuple[GaussianMixture, GaussianMixture]) -def eager_contraction_gaussian(red_op, bin_op, reduced_vars, terms): - x, y = terms + GaussianMixture, GaussianMixture) +def eager_contraction_gaussian(red_op, bin_op, reduced_vars, x, y): return (x + y).reduce(red_op, reduced_vars) @@ -324,11 +324,11 @@ def _(fn): ########################################## ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4} -GROUND_TERMS = Union[Delta, Number, Tensor, Gaussian] +GROUND_TERMS = tuple(ORDERING) -@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, Tuple[GROUND_TERMS, GROUND_TERMS]) -def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, terms): +@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GROUND_TERMS) +def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, *terms): # when bin_op is commutative, put terms into a canonical order for pattern matching new_terms = tuple( v for i, v in sorted(enumerate(terms), @@ -336,31 +336,33 @@ def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_va ) if any(v is not vv for v, vv in zip(terms, new_terms)): return Contraction(red_op, bin_op, reduced_vars, *new_terms) - return None # normalize(Contraction, red_op, bin_op, reduced_vars, new_terms) + return normalize(Contraction, red_op, bin_op, reduced_vars, new_terms) -@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, Tuple[GaussianMixture, GROUND_TERMS]) -def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, terms): - mixture, other = terms +@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GaussianMixture, GROUND_TERMS) +def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other): return Contraction(mixture.red_op if red_op is nullop else red_op, bin_op, reduced_vars | mixture.reduced_vars, 
*(mixture.terms + (other,))) -@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, Tuple[GROUND_TERMS, GaussianMixture]) -def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, terms): - other, mixture = terms +@normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GaussianMixture) +def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture): return Contraction(mixture.red_op if red_op is nullop else red_op, bin_op, reduced_vars | mixture.reduced_vars, *(mixture.terms + (other,))) -@normalize.register(Contraction, NullOp, NullOp, frozenset, Tuple[Funsor]) -def normalize_trivial(red_op, bin_op, reduced_vars, terms): - term, = terms +@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor]) +def normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms): + return normalize(Contraction, red_op, bin_op, reduced_vars, tuple(terms)) + + +@normalize.register(Contraction, NullOp, NullOp, frozenset, Funsor) +def normalize_trivial(red_op, bin_op, reduced_vars, term): assert not reduced_vars return term -@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) +@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): if not reduced_vars and red_op is not nullop: diff --git a/funsor/delta.py b/funsor/delta.py index eec0f0339..1a37d3724 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -187,8 +187,7 @@ def eager_add_multidelta(op, lhs, rhs): return Delta(lhs.terms + rhs.terms) -@eager.register(Binary, (AddOp, SubOp), Delta, Align) -@eager.register(Binary, (AddOp, SubOp), Delta, Funsor) +@eager.register(Binary, (AddOp, SubOp), Delta, (Funsor, Align)) def eager_add_delta_funsor(op, lhs, rhs): if lhs.fresh.intersection(rhs.inputs): rhs = rhs(**{name: point for name, (point, log_density) in lhs.terms if name 
in rhs.inputs}) @@ -197,8 +196,7 @@ def eager_add_delta_funsor(op, lhs, rhs): return None # defer to default implementation -@eager.register(Binary, AddOp, Align, Delta) -@eager.register(Binary, AddOp, Funsor, Delta) +@eager.register(Binary, AddOp, (Funsor, Align), Delta) def eager_add_funsor_delta(op, lhs, rhs): if rhs.fresh.intersection(lhs.inputs): lhs = lhs(**{name: point for name, (point, log_density) in rhs.terms if name in lhs.inputs}) diff --git a/funsor/distribution.py b/funsor/distribution.py index 3789c806a..81af49460 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -718,15 +718,13 @@ def eager_mvn(loc, scale_tril, value): return gaussian(**{var: value - loc}) -def eager_beta_bernoulli(red_op, bin_op, reduced_vars, xy): - x, y = xy +def eager_beta_bernoulli(red_op, bin_op, reduced_vars, x, y): backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, backend_dist.Binomial(total_count=1, probs=y.probs, value=y.value)) -def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, xy): - x, y = xy +def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y): dirichlet_reduction = x.input_vars & reduced_vars if dirichlet_reduction: backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) @@ -738,8 +736,7 @@ def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, xy): return eager(Contraction, red_op, bin_op, reduced_vars, (x, y)) -def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, xy): - x, y = xy +def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y): dirichlet_reduction = x.input_vars & reduced_vars if dirichlet_reduction: backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) @@ -772,8 +769,7 @@ def _log_beta(x, y): return ops.lgamma(x) + ops.lgamma(y) - ops.lgamma(x + y) -def eager_gamma_gamma(red_op, bin_op, reduced_vars, xy): - x, y = xy +def 
eager_gamma_gamma(red_op, bin_op, reduced_vars, x, y): gamma_reduction = x.input_vars & reduced_vars if gamma_reduction: unnormalized = (y.concentration - 1) * ops.log(y.value) \ @@ -784,8 +780,7 @@ def eager_gamma_gamma(red_op, bin_op, reduced_vars, xy): return eager(Contraction, red_op, bin_op, reduced_vars, (x, y)) -def eager_gamma_poisson(red_op, bin_op, reduced_vars, xy): - x, y = xy +def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y): gamma_reduction = x.input_vars & reduced_vars if gamma_reduction: backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) diff --git a/funsor/integrate.py b/funsor/integrate.py index 7c9623d10..3dce6a5f9 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict -from typing import Tuple, Union +from typing import Union import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture @@ -91,10 +91,9 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): @eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset, - Tuple[Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]], - Union[Variable, Delta, Gaussian, Number, Tensor, GaussianMixture]]) -def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, terms): - lhs, rhs = terms + Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]], + (Variable, Delta, Gaussian, Number, Tensor, GaussianMixture)) +def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs): reduced_names = frozenset(v.name for v in reduced_vars) if not (reduced_names.issubset(lhs.inputs) and reduced_names.issubset(rhs.inputs)): args = red_op, bin_op, reduced_vars, (lhs, rhs) diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 94e37afb2..15f685304 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -278,15 +278,15 @@ def 
deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None): eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable) # noqa: F821 eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal) # noqa: F821 eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn) # noqa: F821 -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, BernoulliProbs])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, BernoulliProbs)( # noqa: F821 eager_beta_bernoulli) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Categorical])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Categorical)( # noqa: F821 eager_dirichlet_categorical) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Multinomial])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Multinomial)( # noqa: F821 eager_dirichlet_multinomial) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Gamma])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Gamma)( # noqa: F821 eager_gamma_gamma) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Poisson])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Poisson)( # noqa: F821 eager_gamma_poisson) if hasattr(dist, "DirichletMultinomial"): eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)( # noqa: F821 diff --git a/funsor/joint.py b/funsor/joint.py index 2c085ddc5..09fac9084 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -6,17 +6,21 @@ from functools import reduce from typing import Tuple, Union +from multipledispatch import dispatch +from multipledispatch.variadic import Variadic + import funsor.ops as ops from funsor.cnf import Contraction, 
GaussianMixture from funsor.delta import Delta from funsor.domains import Bint from funsor.gaussian import Gaussian, align_gaussian +from funsor.ops import AssociativeOp from funsor.tensor import Tensor, align_tensor -from funsor.terms import Cat, Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize +from funsor.terms import Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize -@eager.register(Cat, str, Tuple[Union[Gaussian, GaussianMixture], ...], str) -def eager_cat_gaussian(name, parts, part_name): +@dispatch(str, str, Variadic[(Gaussian, GaussianMixture)]) +def eager_cat_homogeneous(name, part_name, *parts): assert parts output = parts[0].output inputs = OrderedDict([(part_name, None)]) @@ -71,9 +75,14 @@ def eager_cat_gaussian(name, parts, part_name): # patterns for moment-matching ################################# -@moment_matching.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Union[Number, Tensor], Gaussian]) -def moment_matching_contract_joint(red_op, bin_op, reduced_vars, terms): - discrete, gaussian = terms +@moment_matching.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[object]) +def moment_matching_contract_default(*args): + return None + + +@moment_matching.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, (Number, Tensor), Gaussian) +def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian): + approx_vars = frozenset(v for v in reduced_vars if v.name in gaussian.inputs if v.dtype != 'real') diff --git a/funsor/optimizer.py b/funsor/optimizer.py index 39f57f34a..9c63a8606 100644 --- a/funsor/optimizer.py +++ b/funsor/optimizer.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import collections -from typing import Tuple +from multipledispatch.variadic import Variadic from opt_einsum.paths import greedy import funsor.interpreter as interpreter @@ -22,7 +22,7 @@ def unfold(cls, *args): return result 
-@unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) +@unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): for i, v in enumerate(terms): @@ -50,6 +50,10 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): return None +unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor])( + lambda r, b, v, *ts: unfold(Contraction, r, b, v, tuple(ts))) + + @interpreter.dispatched_interpretation def optimize(cls, *args): result = optimize.dispatch(cls, *args)(*args) @@ -62,13 +66,17 @@ def optimize(cls, *args): REAL_SIZE = 3 # the "size" of a real-valued dimension passed to the path optimizer -@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, Funsor]) -@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor]) -def eager_contract_base(red_op, bin_op, reduced_vars, terms): +optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor])( + lambda r, b, v, *ts: optimize(Contraction, r, b, v, tuple(ts))) + + +@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor) +@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor) +def eager_contract_base(red_op, bin_op, reduced_vars, *terms): return None -@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Tuple[Funsor, ...]) +@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms): if red_op is nullop or bin_op is nullop or not (red_op, bin_op) in DISTRIBUTIVE_OPS: diff --git a/funsor/tensor.py b/funsor/tensor.py index 1152ad15f..e4ec652fa 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -11,6 +11,8 @@ import numpy as np import opt_einsum +from multipledispatch import dispatch 
+from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -19,13 +21,11 @@ from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp from funsor.terms import ( Binary, - Cat, Funsor, FunsorMeta, Lambda, Number, Slice, - Stack, Tuple, Unary, Variable, @@ -712,8 +712,8 @@ def eager_lambda(var, expr): return Tensor(data, inputs, expr.dtype) -@eager.register(Stack, str, typing.Tuple[Tensor, ...]) -def eager_stack_tensors(name, parts): +@dispatch(str, Variadic[Tensor]) +def eager_stack_homogeneous(name, *parts): assert parts output = parts[0].output part_inputs = OrderedDict() @@ -730,12 +730,9 @@ def eager_stack_tensors(name, parts): return Tensor(data, inputs, dtype=output.dtype) -@eager.register(Cat, str, typing.Tuple[Tensor, ...], str) -def eager_cat_tensors(name, parts, part_name): +@dispatch(str, str, Variadic[Tensor]) +def eager_cat_homogeneous(name, part_name, *parts): assert parts - if len(parts) == 1: - return parts[0](**{part_name: name}) - output = parts[0].output inputs = OrderedDict([(part_name, None)]) for part in parts: diff --git a/funsor/terms.py b/funsor/terms.py index 8f32ba0eb..f6bd22f04 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -12,6 +12,9 @@ from functools import reduce, singledispatch from weakref import WeakValueDictionary +from multipledispatch import dispatch +from multipledispatch.variadic import Variadic, isvariadic + import funsor.interpreter as interpreter import funsor.ops as ops from funsor.domains import Array, Bint, Domain, Product, Real, find_domain @@ -1276,9 +1279,14 @@ def eager_reduce(self, op, reduced_vars): return Stack(self.name, parts) -@eager.register(Stack, str, typing.Tuple[Funsor, ...]) +@eager.register(Stack, str, tuple) def eager_stack(name, parts): - return None + return eager_stack_homogeneous(name, *parts) + + +@dispatch(str, Variadic[Funsor]) +def eager_stack_homogeneous(name, *parts): + return None # defer to default implementation class CatMeta(FunsorMeta): @@ -1370,9 
+1378,16 @@ def eager_subs(self, subs): .format(type(value))) -@eager.register(Cat, str, typing.Tuple[Funsor], str) +@eager.register(Cat, str, tuple, str) def eager_cat(name, parts, part_name): - return parts[0](**{part_name: name}) + if len(parts) == 1: + return parts[0](**{part_name: name}) + return eager_cat_homogeneous(name, part_name, *parts) + + +@dispatch(str, str, Variadic[Funsor]) +def eager_cat_homogeneous(name, part_name, *parts): + return None # defer to default implementation class Lambda(Funsor): @@ -1405,8 +1420,7 @@ def _alpha_convert(self, alpha_subs): return super()._alpha_convert(alpha_subs) -@eager.register(Binary, GetitemOp, Lambda, Align) -@eager.register(Binary, GetitemOp, Lambda, Funsor) +@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align)) def eager_getitem_lambda(op, lhs, rhs): if op.offset == 0: return Subs(lhs.expr, ((lhs.var.name, rhs),)) diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index 84d6f8e5b..6fe1b3186 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -345,15 +345,15 @@ def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None): eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable) # noqa: F821 eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal) # noqa: F821 eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn) # noqa: F821 -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, BernoulliProbs])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, BernoulliProbs)( # noqa: F821 eager_beta_bernoulli) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, Categorical])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Categorical)( # noqa: F821 eager_dirichlet_categorical) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Dirichlet, 
Multinomial])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Multinomial)( # noqa: F821 eager_dirichlet_multinomial) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Gamma])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Gamma)( # noqa: F821 eager_gamma_gamma) -eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tuple[Gamma, Poisson])( # noqa: F821 +eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Poisson)( # noqa: F821 eager_gamma_poisson) eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)( # noqa: F821 eager_dirichlet_posterior) From 1c950bf840985751f6cea69acad8811666d97752 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 18:48:01 -0500 Subject: [PATCH 23/66] lint --- funsor/terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/terms.py b/funsor/terms.py index f6bd22f04..704e62129 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -13,7 +13,7 @@ from weakref import WeakValueDictionary from multipledispatch import dispatch -from multipledispatch.variadic import Variadic, isvariadic +from multipledispatch.variadic import Variadic import funsor.interpreter as interpreter import funsor.ops as ops From 04309448d728247762a9a33d84335d92438278cf Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 18:52:43 -0500 Subject: [PATCH 24/66] add typing_extensions dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 0b806bfd8..2b5e05659 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ 'multipledispatch', 'numpy>=1.7', 'opt_einsum>=2.3.2', + 'typing_extensions', ], extras_require={ 'torch': [ From 22b180b8e4c112910aa071511b55462c8fc2a964 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 19:22:04 -0500 Subject: [PATCH 25/66] implement add instead of register in typingdispatcher and add 
Variadic --- funsor/typing.py | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 37f286ebf..68d822441 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -7,8 +7,9 @@ import weakref from multipledispatch.conflict import supercedes -from multipledispatch.dispatcher import Dispatcher -from multipledispatch.variadic import VariadicSignatureType, isvariadic +from multipledispatch.dispatcher import Dispatcher, expand_tuples +from multipledispatch.variadic import isvariadic +from multipledispatch.variadic import Variadic as _OrigVariadic ################################# @@ -215,8 +216,12 @@ def deep_supercedes(xs, ys): tuple(typing_wrap(_type_to_typing(y)) for y in ys)) -class _DeepVariadicSignatureType(VariadicSignatureType): - pass # TODO define __getitem__, possibly __eq__/__hash__? +class _DeepVariadicSignatureType(type): + + def __getitem__(cls, key): + if not isinstance(key, tuple): + key = (key,) + return _OrigVariadic[tuple(map(typing_wrap, map(_type_to_typing, key)))] class Variadic(metaclass=_DeepVariadicSignatureType): @@ -230,13 +235,29 @@ class TypingDispatcher(Dispatcher): """ A Dispatcher class designed for compatibility with the typing standard library. """ - def register(self, *types): - types = tuple(map(typing_wrap, map(_type_to_typing, types))) + def add(self, signature, func): + + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + signature = (Variadic[tp] if isinstance(tp, list) else tp for tp in signature) + signature = tuple(map(typing_wrap, map(_type_to_typing, signature))) + + super().add(signature, func) + if getattr(self, "default", None): # XXX should this class have default? 
- objects = (typing_wrap(typing.Any),) * len(types) - if objects != types and deep_supercedes(types, objects): - super().register(*objects)(self.default) - return super().register(*types) + objects = (typing_wrap(typing.Any),) * len(signature) + if objects != signature and deep_supercedes(signature, objects): + super().add(objects, self.default) def partial_call(self, *args): """ From 6509089522472613722202930ff9a20b7e694cd8 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 19:34:05 -0500 Subject: [PATCH 26/66] use new variadic throughout --- funsor/cnf.py | 2 +- funsor/joint.py | 2 +- funsor/optimizer.py | 2 +- funsor/tensor.py | 2 +- funsor/terms.py | 3 +-- funsor/typing.py | 3 +-- 6 files changed, 6 insertions(+), 8 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 71a9cc59c..c5844e8a9 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -8,7 +8,6 @@ from typing import Tuple, Union import opt_einsum -from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -33,6 +32,7 @@ reflect, to_funsor ) +from funsor.typing import Variadic from funsor.util import broadcast_shape, get_backend, quote diff --git a/funsor/joint.py b/funsor/joint.py index 09fac9084..cee71fb34 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -7,7 +7,6 @@ from typing import Tuple, Union from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture @@ -17,6 +16,7 @@ from funsor.ops import AssociativeOp from funsor.tensor import Tensor, align_tensor from funsor.terms import Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize +from funsor.typing import Variadic @dispatch(str, str, Variadic[(Gaussian, GaussianMixture)]) diff --git a/funsor/optimizer.py b/funsor/optimizer.py index aad009771..5b2371a8c 100644 --- a/funsor/optimizer.py +++ b/funsor/optimizer.py @@ -3,13 +3,13 @@ import collections -from 
multipledispatch.variadic import Variadic from opt_einsum.paths import greedy import funsor.interpreter as interpreter from funsor.cnf import Contraction, nullop from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp from funsor.terms import Funsor, eager, lazy, normalize +from funsor.typing import Variadic @interpreter.dispatched_interpretation diff --git a/funsor/tensor.py b/funsor/tensor.py index e4ec652fa..245e940ac 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -12,7 +12,6 @@ import numpy as np import opt_einsum from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor import funsor.ops as ops @@ -35,6 +34,7 @@ to_funsor ) from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, quote +from funsor.typing import Variadic def get_default_prototype(): diff --git a/funsor/terms.py b/funsor/terms.py index 4d791e72c..fc91f9037 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -13,14 +13,13 @@ from weakref import WeakValueDictionary from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor.interpreter as interpreter import funsor.ops as ops from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.typing import GenericTypeMeta, deep_type +from funsor.typing import GenericTypeMeta, Variadic, deep_type from funsor.util import getargspec, get_backend, lazy_property, pretty, quote diff --git a/funsor/typing.py b/funsor/typing.py index 68d822441..00e33b708 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -263,8 +263,7 @@ def partial_call(self, *args): """ Likde :meth:`__call__` but avoids calling ``func()``. 
""" - types = tuple(map(deep_type, args)) - types = tuple(map(_type_to_typing, types)) + types = tuple(map(typing_wrap, map(_type_to_typing, map(deep_type, args)))) try: func = self._cache[types] except KeyError: From dec06a212252156fd7ef059ef739efbc1904aa3c Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 21:02:48 -0500 Subject: [PATCH 27/66] nits --- funsor/typing.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 00e33b708..aab3cceff 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -101,7 +101,6 @@ def _deep_issubclass(subcls, cls): return issubclass(subcls, cls) -@functools.lru_cache(maxsize=None) def _type_to_typing(tp): if tp is object: tp = typing.Any @@ -127,8 +126,6 @@ def get_origin(tp): def get_type_hints(obj, globalns=None, localns=None, include_extras=False): - if isinstance(obj, GenericTypeMeta): - return getattr(obj, "__annotations__", {}) return typing_extensions.get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) @@ -228,7 +225,7 @@ class Variadic(metaclass=_DeepVariadicSignatureType): """ A typing-compatible drop-in replacement for multipledispatch.variadic.Variadic. """ - pass # TODO is there anything else to do here? 
+ pass class TypingDispatcher(Dispatcher): From 38c40dfd2e2ff132c3ddce909b1688bae5fd66b8 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 21:46:06 -0500 Subject: [PATCH 28/66] move dispatcher to registry --- funsor/registry.py | 47 ++++++++++++++++++++++++++++++++++-- funsor/typing.py | 59 ++-------------------------------------------- 2 files changed, 47 insertions(+), 59 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index 9ad808ca6..5c6ef35ab 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -3,16 +3,59 @@ from collections import defaultdict -from funsor.typing import TypingDispatcher, get_origin +from multipledispatch.dispatcher import Dispatcher, expand_tuples +from funsor.typing import Variadic, deep_type, get_origin, typing_wrap -class PartialDispatcher(TypingDispatcher): + +class PartialDispatcher(Dispatcher): """ Wrapper to avoid appearance in stack traces. """ def __init__(self, name, default=None): self.default = default if default is None else PartialDefault(default) super().__init__(name) + if default is not None: + self.add((Variadic[object],), self.default) + + def add(self, signature, func): + + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle some union types by expanding at registration time + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + # Handle variadic types + signature = (Variadic[tp] if isinstance(tp, list) else tp for tp in signature) + + signature = tuple(map(typing_wrap, signature)) + super().add(signature, func) + + def partial_call(self, *args): + """ + Likde :meth:`__call__` but avoids calling ``func()``. 
+ """ + types = tuple(map(typing_wrap, map(deep_type, args))) + try: + func = self._cache[types] + except KeyError: + func = self.dispatch(*types) + if func is None: + raise NotImplementedError( + 'Could not find signature for %s: <%s>' % + (self.name, ', '.join(cls.__name__ for cls in types))) + self._cache[types] = func + return func + + def __call__(self, *args): + return self.partial_call(*args)(*args) class PartialDefault: diff --git a/funsor/typing.py b/funsor/typing.py index aab3cceff..afd7007f2 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -6,8 +6,6 @@ import typing_extensions import weakref -from multipledispatch.conflict import supercedes -from multipledispatch.dispatcher import Dispatcher, expand_tuples from multipledispatch.variadic import isvariadic from multipledispatch.variadic import Variadic as _OrigVariadic @@ -192,6 +190,7 @@ def classname(cls): class _RuntimeSubclassCheckMeta(GenericTypeMeta): def __call__(cls, tp): + tp = _type_to_typing(tp) return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else cls[tp] def __subclasscheck__(cls, subcls): @@ -207,18 +206,12 @@ class typing_wrap(metaclass=_RuntimeSubclassCheckMeta): pass -def deep_supercedes(xs, ys): - """typing-compatible version of multipledispatch.conflict.supercedes""" - return supercedes(tuple(typing_wrap(_type_to_typing(x)) for x in xs), - tuple(typing_wrap(_type_to_typing(y)) for y in ys)) - - class _DeepVariadicSignatureType(type): def __getitem__(cls, key): if not isinstance(key, tuple): key = (key,) - return _OrigVariadic[tuple(map(typing_wrap, map(_type_to_typing, key)))] + return _OrigVariadic[tuple(map(typing_wrap, key))] class Variadic(metaclass=_DeepVariadicSignatureType): @@ -226,51 +219,3 @@ class Variadic(metaclass=_DeepVariadicSignatureType): A typing-compatible drop-in replacement for multipledispatch.variadic.Variadic. 
""" pass - - -class TypingDispatcher(Dispatcher): - """ - A Dispatcher class designed for compatibility with the typing standard library. - """ - def add(self, signature, func): - - # Handle annotations - if not signature: - annotations = self.get_func_annotations(func) - if annotations: - signature = annotations - - # Handle union types - if any(isinstance(typ, tuple) for typ in signature): - for typs in expand_tuples(signature): - self.add(typs, func) - return - - signature = (Variadic[tp] if isinstance(tp, list) else tp for tp in signature) - signature = tuple(map(typing_wrap, map(_type_to_typing, signature))) - - super().add(signature, func) - - if getattr(self, "default", None): # XXX should this class have default? - objects = (typing_wrap(typing.Any),) * len(signature) - if objects != signature and deep_supercedes(signature, objects): - super().add(objects, self.default) - - def partial_call(self, *args): - """ - Likde :meth:`__call__` but avoids calling ``func()``. - """ - types = tuple(map(typing_wrap, map(_type_to_typing, map(deep_type, args)))) - try: - func = self._cache[types] - except KeyError: - func = self.dispatch(*types) - if func is None: - raise NotImplementedError( - 'Could not find signature for %s: <%s>' % - (self.name, ', '.join(cls.__name__ for cls in types))) - self._cache[types] = func - return func - - def __call__(self, *args): - return self.partial_call(*args)(*args) From 8b9fc57cf0015339210483c4adc68950aec4f148 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 21:54:18 -0500 Subject: [PATCH 29/66] remove aliasing and move classname to FunsorMeta --- funsor/terms.py | 4 ++++ funsor/typing.py | 15 +++------------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index fc91f9037..d511e7753 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -238,6 +238,10 @@ def __call__(cls, *args, **kwargs): return interpret(cls, *args) + @property + def classname(cls): + return repr(cls) + def 
_convert_reduced_vars(reduced_vars, inputs): """ diff --git a/funsor/typing.py b/funsor/typing.py index afd7007f2..19342d834 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -20,20 +20,9 @@ def deep_isinstance(obj, cls): return deep_issubclass(deep_type(obj), cls) -def deep_issubclass(subcls, cls): - """replaces issubclass()""" - # return pytypes.is_subtype(subcls, cls) - return _deep_issubclass(subcls, cls) - - def deep_type(obj): """replaces type()""" # return pytypes.deep_type(obj) - return _deep_type(obj) - - -def _deep_type(obj): - if isinstance(obj, tuple): return typing.Tuple[tuple(map(deep_type, obj))] if obj else typing.Tuple @@ -44,7 +33,9 @@ def _deep_type(obj): @functools.lru_cache(maxsize=None) -def _deep_issubclass(subcls, cls): +def deep_issubclass(subcls, cls): + """replaces issubclass()""" + # return pytypes.is_subtype(subcls, cls) if get_origin(cls) is typing.Union: return any(_deep_issubclass(subcls, arg) for arg in get_args(cls)) From d21239433e7dc058780af59c222b06a051894d81 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 21:58:48 -0500 Subject: [PATCH 30/66] fix typo --- funsor/typing.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 19342d834..0ba188a6d 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -38,10 +38,10 @@ def deep_issubclass(subcls, cls): # return pytypes.is_subtype(subcls, cls) if get_origin(cls) is typing.Union: - return any(_deep_issubclass(subcls, arg) for arg in get_args(cls)) + return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) if get_origin(subcls) is typing.Union: - return all(_deep_issubclass(arg, cls) for arg in get_args(subcls)) + return all(deep_issubclass(arg, cls) for arg in get_args(subcls)) if cls is typing.Any: return True @@ -61,7 +61,7 @@ def deep_issubclass(subcls, cls): return get_args(cls)[0] is typing.Any return len(get_args(subcls)) == len(get_args(cls)) == 1 and \ - 
_deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) + deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) if issubclass(get_origin(cls), typing.Tuple): @@ -76,8 +76,8 @@ def deep_issubclass(subcls, cls): if get_args(cls)[-1] is Ellipsis: # cls variadic if get_args(subcls)[-1] is Ellipsis: # both variadic - return _deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) - return all(_deep_issubclass(a, get_args(cls)[0]) for a in get_args(subcls)) + return deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) + return all(deep_issubclass(a, get_args(cls)[0]) for a in get_args(subcls)) if get_args(subcls)[-1] is Ellipsis: # only subcls variadic # issubclass(Tuple[A, ...], Tuple[X, Y]) == False @@ -85,7 +85,7 @@ def deep_issubclass(subcls, cls): # neither variadic return len(get_args(cls)) == len(get_args(subcls)) and \ - all(_deep_issubclass(a, b) for a, b in zip(get_args(subcls), get_args(cls))) + all(deep_issubclass(a, b) for a, b in zip(get_args(subcls), get_args(cls))) return issubclass(subcls, cls) @@ -170,10 +170,6 @@ def __repr__(cls): "" if not get_args(cls) else "[{}]".format(", ".join(repr(t) for t in get_args(cls)))) - @property - def classname(cls): - return repr(cls) - ############################################################## # Tools and overrides for typing-compatible multipledispatch From c36dd4833c4b65c341bcfbf23c96a2bbdf1198b6 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 22:16:12 -0500 Subject: [PATCH 31/66] Remove changes to domains from this branch --- funsor/domains.py | 114 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 85 insertions(+), 29 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index ba7111b42..75fcfe3b4 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -6,20 +6,20 @@ import operator import warnings from functools import reduce +from weakref import WeakValueDictionary import funsor.ops as ops -from funsor.typing import GenericTypeMeta from funsor.util import 
broadcast_shape, get_backend, get_tracing_state, quote - -class Domain(GenericTypeMeta): - pass +Domain = type class ArrayType(Domain): """ Base class of array-like domains. """ + _type_cache = WeakValueDictionary() + def __getitem__(cls, dtype_shape): dtype, shape = dtype_shape assert dtype is not None @@ -32,7 +32,23 @@ def __getitem__(cls, dtype_shape): if shape is not None: shape = tuple(map(int, shape)) - return super().__getitem__((dtype, shape)) + assert cls.dtype in (None, dtype) + assert cls.shape in (None, shape) + key = dtype, shape + result = ArrayType._type_cache.get(key, None) + if result is None: + if dtype == "real": + assert all(isinstance(size, int) and size >= 0 for size in shape) + name = "Reals[{}]".format(",".join(map(str, shape))) if shape else "Real" + result = RealsType(name, (), {"shape": shape}) + elif isinstance(dtype, int): + assert dtype >= 0 + name = "Bint[{}, {}]".format(dtype, ",".join(map(str, shape))) + result = BintType(name, (), {"dtype": dtype, "shape": shape}) + else: + raise ValueError("invalid dtype: {}".format(dtype)) + ArrayType._type_cache[key] = result + return result def __subclasscheck__(cls, subcls): if not isinstance(subcls, ArrayType): @@ -43,26 +59,16 @@ def __subclasscheck__(cls, subcls): return False return True - @property - def dtype(cls): - return cls.__args__[0] + def __repr__(cls): + return cls.__name__ - @property - def shape(cls): - return cls.__args__[1] + def __str__(cls): + return cls.__name__ @property def num_elements(cls): return reduce(operator.mul, cls.shape, 1) - @property - def size(cls): - return cls.dtype - - def __iter__(cls): - from funsor.terms import Number - return (Number(i, cls.size) for i in range(cls.size)) - class BintType(ArrayType): def __getitem__(cls, size_shape): @@ -70,19 +76,47 @@ def __getitem__(cls, size_shape): size, shape = size_shape[0], size_shape[1:] else: size, shape = size_shape, () - return Array.__getitem__((size, shape)) + return super().__getitem__((size, shape)) + 
+ def __subclasscheck__(cls, subcls): + if not isinstance(subcls, BintType): + return False + if cls.dtype not in (None, subcls.dtype): + return False + if cls.shape not in (None, subcls.shape): + return False + return True + + @property + def size(cls): + return cls.dtype + + def __iter__(cls): + from funsor.terms import Number + return (Number(i, cls.size) for i in range(cls.size)) class RealsType(ArrayType): + dtype = "real" + def __getitem__(cls, shape): if not isinstance(shape, tuple): shape = (shape,) - return Array.__getitem__(("real", shape)) + return super().__getitem__(("real", shape)) + + def __subclasscheck__(cls, subcls): + if not isinstance(subcls, RealsType): + return False + if cls.dtype not in (None, subcls.dtype): + return False + if cls.shape not in (None, subcls.shape): + return False + return True def _pickle_array(cls): - if cls in (Array, Bint, Reals): - return repr(cls) + if cls in (Array, Bint, Real, Reals): + return cls.__name__ return operator.getitem, (Array, (cls.dtype, cls.shape)) @@ -98,20 +132,22 @@ class Array(metaclass=ArrayType): Arary["real", (3, 3)] = Reals[3, 3] Array["real", ()] = Real """ - pass + dtype = None + shape = None -class Bint(Array, metaclass=BintType): +class Bint(metaclass=BintType): """ Factory for bounded integer types:: Bint[5] # integers ranging in {0,1,2,3,4} Bint[2, 3, 3] # 3x3 matrices with entries in {0,1} """ - pass + dtype = None + shape = None -class Reals(Array, metaclass=RealsType): +class Reals(metaclass=RealsType): """ Type of a real-valued array with known shape:: @@ -119,7 +155,7 @@ class Reals(Array, metaclass=RealsType): Reals[8] # vector of length 8 Reals[3, 3] # 3x3 matrix """ - pass + shape = None Real = Reals[()] @@ -140,6 +176,26 @@ def bint(size): class ProductDomain(Domain): + + _type_cache = WeakValueDictionary() + + def __getitem__(cls, arg_domains): + try: + return ProductDomain._type_cache[arg_domains] + except KeyError: + assert isinstance(arg_domains, tuple) + assert 
all(isinstance(arg_domain, Domain) for arg_domain in arg_domains) + subcls = type("Product_", (Product,), {"__args__": arg_domains}) + ProductDomain._type_cache[arg_domains] = subcls + return subcls + + def __repr__(cls): + return "Product[{}]".format(", ".join(map(repr, cls.__args__))) + + @property + def __origin__(cls): + return Product + @property def shape(cls): return (len(cls.__args__),) @@ -147,7 +203,7 @@ def shape(cls): class Product(tuple, metaclass=ProductDomain): """like typing.Tuple, but works with issubclass""" - pass + __args__ = NotImplemented @quote.register(BintType) From 0b71b48d68f55733a2c043ad81d4cba8a8e190d8 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 30 Jan 2021 22:21:07 -0500 Subject: [PATCH 32/66] keep classname lazy --- funsor/terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/terms.py b/funsor/terms.py index d511e7753..65790a277 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -238,7 +238,7 @@ def __call__(cls, *args, **kwargs): return interpret(cls, *args) - @property + @lazy_property def classname(cls): return repr(cls) From 78bb134430493616fbee9682b5d89bd4d97138c8 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 2 Feb 2021 13:45:55 -0500 Subject: [PATCH 33/66] run black --- examples/discrete_hmm.py | 2 +- examples/eeg_slds.py | 4 ++-- examples/kalman_filter.py | 2 +- examples/mixed_hmm/model.py | 2 +- examples/sensor.py | 2 +- examples/slds.py | 2 +- examples/vae.py | 2 +- funsor/__init__.py | 3 +-- funsor/adjoint.py | 2 +- funsor/affine.py | 2 +- funsor/distribution.py | 16 +++++++++++----- funsor/einsum/numpy_map.py | 1 - funsor/integrate.py | 2 +- funsor/jax/__init__.py | 2 +- funsor/jax/distributions.py | 8 +++----- funsor/joint.py | 3 +-- funsor/pyro/distribution.py | 2 +- funsor/registry.py | 3 +-- funsor/tensor.py | 4 ++-- funsor/terms.py | 5 ++--- funsor/testing.py | 8 ++++---- funsor/torch/__init__.py | 6 +++--- funsor/typing.py | 6 +++--- scripts/update_headers.py | 2 +- 
test/examples/test_bart.py | 2 +- test/examples/test_sensor_fusion.py | 2 +- test/test_affine.py | 2 +- test/test_cnf.py | 5 ++--- test/test_domains.py | 2 +- test/test_gaussian.py | 3 +-- test/test_joint.py | 3 +-- test/test_memoize.py | 1 - test/test_ops.py | 2 +- test/test_sum_product.py | 6 +++--- test/test_tensor.py | 2 +- 35 files changed, 58 insertions(+), 63 deletions(-) diff --git a/examples/discrete_hmm.py b/examples/discrete_hmm.py index 20b9435ca..1ff8d3718 100644 --- a/examples/discrete_hmm.py +++ b/examples/discrete_hmm.py @@ -7,8 +7,8 @@ import torch import funsor -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.interpreter import interpretation, reinterpret from funsor.optimizer import apply_optimizer from funsor.terms import lazy diff --git a/examples/eeg_slds.py b/examples/eeg_slds.py index aef92a9ad..d9eb8107a 100644 --- a/examples/eeg_slds.py +++ b/examples/eeg_slds.py @@ -17,13 +17,13 @@ from urllib.request import urlopen import numpy as np +import pyro import torch import torch.nn as nn -import pyro import funsor -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.pyro.convert import funsor_to_cat_and_mvn, funsor_to_mvn, matrix_and_mvn_to_funsor, mvn_to_funsor diff --git a/examples/kalman_filter.py b/examples/kalman_filter.py index 0f807a436..2a8f7862e 100644 --- a/examples/kalman_filter.py +++ b/examples/kalman_filter.py @@ -6,8 +6,8 @@ import torch import funsor -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.interpreter import interpretation, reinterpret from funsor.optimizer import apply_optimizer from funsor.terms import lazy diff --git a/examples/mixed_hmm/model.py b/examples/mixed_hmm/model.py index b84479665..8257e7f07 100644 --- a/examples/mixed_hmm/model.py +++ b/examples/mixed_hmm/model.py @@ -7,8 +7,8 @@ import 
torch from torch.distributions import constraints -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.domains import Bint, Reals from funsor.tensor import Tensor from funsor.terms import Stack, Variable, to_funsor diff --git a/examples/sensor.py b/examples/sensor.py index fca7d2f55..fff252ec8 100644 --- a/examples/sensor.py +++ b/examples/sensor.py @@ -12,8 +12,8 @@ from torch.optim import Adam import funsor -import funsor.torch.distributions as f_dist import funsor.ops as ops +import funsor.torch.distributions as f_dist from funsor.domains import Reals from funsor.pyro.convert import dist_to_funsor, funsor_to_mvn from funsor.tensor import Tensor, Variable diff --git a/examples/slds.py b/examples/slds.py index 9ade1fd88..e69b11033 100644 --- a/examples/slds.py +++ b/examples/slds.py @@ -6,8 +6,8 @@ import torch import funsor -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist def main(args): diff --git a/examples/vae.py b/examples/vae.py index 81b48b818..2a5cfbb47 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -12,8 +12,8 @@ from torchvision import datasets, transforms import funsor -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.domains import Bint, Reals REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) diff --git a/funsor/__init__.py b/funsor/__init__.py index 6b2c1e4f1..0c34db9d6 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -3,7 +3,7 @@ from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals from funsor.integrate import Integrate -from funsor.interpreter import reinterpret, interpretation +from funsor.interpreter import interpretation, reinterpret from funsor.sum_product import MarkovProduct from funsor.tensor import Tensor, function from funsor.terms import ( @@ -42,7 +42,6 @@ testing ) 
- __all__ = [ 'Array', 'Bint', diff --git a/funsor/adjoint.py b/funsor/adjoint.py index fd0c9c31a..64f114fea 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -13,8 +13,8 @@ from funsor.interpreter import interpretation from funsor.ops import AssociativeOp from funsor.registry import KeyedRegistry -from funsor.terms import Binary, Cat, Funsor, Number, Reduce, Slice, Subs, Variable, reflect, substitute, to_funsor from funsor.tensor import Tensor +from funsor.terms import Binary, Cat, Funsor, Number, Reduce, Slice, Subs, Variable, reflect, substitute, to_funsor def _alpha_unmangle(expr): diff --git a/funsor/affine.py b/funsor/affine.py index 630502ec9..9ea6e273e 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -8,7 +8,7 @@ from funsor.interpreter import gensym from funsor.tensor import Einsum, Tensor, get_default_prototype -from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, Bint +from funsor.terms import Binary, Bint, Funsor, Lambda, Reduce, Unary, Variable from . 
import ops diff --git a/funsor/distribution.py b/funsor/distribution.py index 00e8feb48..c710ed57f 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -19,13 +19,19 @@ from funsor.domains import Array, Real, Reals from funsor.gaussian import Gaussian from funsor.interpreter import gensym -from funsor.tensor import (Function, Tensor, align_tensors, dummy_numeric_array, get_default_prototype, - ignore_jit_warnings, numeric_array, stack) -from funsor.terms import Funsor, FunsorMeta, Independent, Lambda, Number, Variable, \ - eager, reflect, to_data, to_funsor +from funsor.tensor import ( + Function, + Tensor, + align_tensors, + dummy_numeric_array, + get_default_prototype, + ignore_jit_warnings, + numeric_array, + stack +) +from funsor.terms import Funsor, FunsorMeta, Independent, Lambda, Number, Variable, eager, reflect, to_data, to_funsor from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property - BACKEND_TO_DISTRIBUTIONS_BACKEND = { "torch": "funsor.torch.distributions", "jax": "funsor.jax.distributions", diff --git a/funsor/einsum/numpy_map.py b/funsor/einsum/numpy_map.py index 02950ea01..d6c7eb13a 100644 --- a/funsor/einsum/numpy_map.py +++ b/funsor/einsum/numpy_map.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import operator - from functools import reduce import funsor.ops as ops diff --git a/funsor/integrate.py b/funsor/integrate.py index 3dce6a5f9..2248eb500 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -7,7 +7,7 @@ import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture from funsor.delta import Delta -from funsor.gaussian import Gaussian, align_gaussian, _mv, _trace_mm, _vv +from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian from funsor.tensor import Tensor from funsor.terms import ( Funsor, diff --git a/funsor/jax/__init__.py b/funsor/jax/__init__.py index 008c33a4e..ec39c1f82 100644 --- a/funsor/jax/__init__.py +++ b/funsor/jax/__init__.py @@ -9,8 
+9,8 @@ import funsor.ops as ops from funsor.adjoint import adjoint_ops from funsor.interpreter import children, recursion_reinterpret -from funsor.terms import Funsor, to_funsor from funsor.tensor import Tensor, tensor_to_funsor +from funsor.terms import Funsor, to_funsor from funsor.util import quote diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 14a9d14e1..761d0949d 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -4,12 +4,12 @@ import functools from typing import Tuple, Union +import funsor.ops as ops import numpyro.distributions as dist - from funsor.cnf import Contraction from funsor.distribution import ( # noqa: F401 - Bernoulli, FUNSOR_DIST_NAMES, + Bernoulli, LogNormal, backenddist_to_funsor, eager_beta, @@ -34,15 +34,13 @@ indepdist_to_funsor, make_dist, maskeddist_to_funsor, - transformeddist_to_funsor, + transformeddist_to_funsor ) from funsor.domains import Real, Reals -import funsor.ops as ops from funsor.tensor import Tensor from funsor.terms import Binary, Funsor, Reduce, Variable, eager, to_data, to_funsor from funsor.util import methodof - ################################################################################ # Distribution Wrappers ################################################################################ diff --git a/funsor/joint.py b/funsor/joint.py index cee71fb34..4dc48f431 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -6,8 +6,6 @@ from functools import reduce from typing import Tuple, Union -from multipledispatch import dispatch - import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture from funsor.delta import Delta @@ -17,6 +15,7 @@ from funsor.tensor import Tensor, align_tensor from funsor.terms import Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize from funsor.typing import Variadic +from multipledispatch import dispatch @dispatch(str, str, Variadic[(Gaussian, GaussianMixture)]) diff --git 
a/funsor/pyro/distribution.py b/funsor/pyro/distribution.py index 078cb02b3..cea5593de 100644 --- a/funsor/pyro/distribution.py +++ b/funsor/pyro/distribution.py @@ -4,8 +4,8 @@ from collections import OrderedDict import torch -from torch.distributions import constraints from pyro.distributions import TorchDistribution +from torch.distributions import constraints from funsor.cnf import Contraction from funsor.delta import Delta diff --git a/funsor/registry.py b/funsor/registry.py index 5c6ef35ab..853e1c3b7 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -3,9 +3,8 @@ from collections import defaultdict -from multipledispatch.dispatcher import Dispatcher, expand_tuples - from funsor.typing import Variadic, deep_type, get_origin, typing_wrap +from multipledispatch.dispatcher import Dispatcher, expand_tuples class PartialDispatcher(Dispatcher): diff --git a/funsor/tensor.py b/funsor/tensor.py index 245e940ac..c498d3af9 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -11,7 +11,6 @@ import numpy as np import opt_einsum -from multipledispatch import dispatch import funsor import funsor.ops as ops @@ -33,8 +32,9 @@ to_data, to_funsor ) -from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, quote from funsor.typing import Variadic +from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, quote +from multipledispatch import dispatch def get_default_prototype(): diff --git a/funsor/terms.py b/funsor/terms.py index 65790a277..2cc6bdd8c 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -12,15 +12,14 @@ from functools import reduce, singledispatch from weakref import WeakValueDictionary -from multipledispatch import dispatch - import funsor.interpreter as interpreter import funsor.ops as ops from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op 
from funsor.typing import GenericTypeMeta, Variadic, deep_type -from funsor.util import getargspec, get_backend, lazy_property, pretty, quote +from funsor.util import get_backend, getargspec, lazy_property, pretty, quote +from multipledispatch import dispatch def substitute(expr, subs): diff --git a/funsor/testing.py b/funsor/testing.py index b50701fb7..d71e110a7 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -11,17 +11,17 @@ import numpy as np import opt_einsum -from multipledispatch import dispatch -from multipledispatch.variadic import Variadic import funsor.ops as ops from funsor.cnf import Contraction from funsor.delta import Delta -from funsor.domains import Domain, Bint, Real +from funsor.domains import Bint, Domain, Real from funsor.gaussian import Gaussian -from funsor.terms import Funsor, Number from funsor.tensor import Tensor +from funsor.terms import Funsor, Number from funsor.util import get_backend +from multipledispatch import dispatch +from multipledispatch.variadic import Variadic @contextlib.contextmanager diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index e5731ef96..eccb81a01 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from multipledispatch import dispatch +import funsor.ops as ops import funsor.torch.distributions # noqa: F401 import funsor.torch.ops # noqa: F401 -import funsor.ops as ops from funsor.adjoint import adjoint_ops from funsor.interpreter import children, recursion_reinterpret -from funsor.terms import Funsor, to_funsor from funsor.tensor import Tensor, tensor_to_funsor +from funsor.terms import Funsor, to_funsor from funsor.util import quote +from multipledispatch import dispatch @adjoint_ops.register(Tensor, ops.AssociativeOp, ops.AssociativeOp, Funsor, torch.Tensor, tuple, object) diff --git a/funsor/typing.py b/funsor/typing.py index 0ba188a6d..4dc46dc8b 100644 --- a/funsor/typing.py +++ 
b/funsor/typing.py @@ -3,12 +3,12 @@ import functools import typing -import typing_extensions import weakref -from multipledispatch.variadic import isvariadic -from multipledispatch.variadic import Variadic as _OrigVariadic +import typing_extensions +from multipledispatch.variadic import Variadic as _OrigVariadic +from multipledispatch.variadic import isvariadic ################################# # Runtime type-checking helpers diff --git a/scripts/update_headers.py b/scripts/update_headers.py index 5880550cd..f71581ecd 100644 --- a/scripts/update_headers.py +++ b/scripts/update_headers.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import os import glob +import os root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) blacklist = ["/build/", "/dist/"] diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py index b29dabe08..5658c8e50 100644 --- a/test/examples/test_bart.py +++ b/test/examples/test_bart.py @@ -7,8 +7,8 @@ import torch import funsor -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.cnf import Contraction from funsor.domains import Bint, Real, Reals from funsor.gaussian import Gaussian diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py index 3a60aa6b0..92011dc20 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -6,8 +6,8 @@ import pytest import torch -import funsor.torch.distributions as dist import funsor.ops as ops +import funsor.torch.distributions as dist from funsor.cnf import Contraction from funsor.domains import Bint, Reals from funsor.gaussian import Gaussian diff --git a/test/test_affine.py b/test/test_affine.py index 78246c30e..64ed517e0 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -8,9 +8,9 @@ from funsor.affine import extract_affine, is_affine from funsor.cnf import Contraction from 
funsor.domains import Bint, Real, Reals # noqa: F401 +from funsor.tensor import Einsum, Tensor from funsor.terms import Number, Unary, Variable from funsor.testing import assert_close, check_funsor, ones, randn, random_gaussian, random_tensor # noqa: F401 -from funsor.tensor import Einsum, Tensor assert random_gaussian # flake8 diff --git a/test/test_cnf.py b/test/test_cnf.py index a5e184955..79c5c9560 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -8,9 +8,8 @@ import pytest from funsor import ops -from funsor.cnf import Contraction, BACKEND_TO_EINSUM_BACKEND, BACKEND_TO_LOGSUMEXP_BACKEND -from funsor.domains import Array, Bint # noqa F403 -from funsor.domains import Reals +from funsor.cnf import BACKEND_TO_EINSUM_BACKEND, BACKEND_TO_LOGSUMEXP_BACKEND, Contraction +from funsor.domains import Array, Bint, Reals # noqa F403 from funsor.einsum import einsum, naive_plated_einsum from funsor.interpreter import interpretation, reinterpret from funsor.tensor import Tensor diff --git a/test/test_domains.py b/test/test_domains.py index 31abbbe03..04bc2fa46 100644 --- a/test/test_domains.py +++ b/test/test_domains.py @@ -6,7 +6,7 @@ import pytest -from funsor.domains import Bint, Real, Reals, Bint, Reals # noqa F401 +from funsor.domains import Bint, Real, Reals # noqa F401 @pytest.mark.parametrize('expr', [ diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 58e59ab61..a4cf040d6 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -15,8 +15,7 @@ from funsor.integrate import Integrate from funsor.tensor import Einsum, Tensor, numeric_array from funsor.terms import Number, Variable -from funsor.testing import (assert_close, id_from_inputs, ones, randn, random_gaussian, - random_tensor, zeros) +from funsor.testing import assert_close, id_from_inputs, ones, randn, random_gaussian, random_tensor, zeros from funsor.util import get_backend assert Einsum # flake8 diff --git a/test/test_joint.py b/test/test_joint.py index 83ce3889c..166526b54 100644 
--- a/test/test_joint.py +++ b/test/test_joint.py @@ -17,8 +17,7 @@ from funsor.montecarlo import MonteCarlo from funsor.tensor import Tensor, numeric_array from funsor.terms import Number, Variable, eager, moment_matching -from funsor.testing import (assert_close, randn, random_gaussian, random_tensor, - zeros, xfail_if_not_implemented) +from funsor.testing import assert_close, randn, random_gaussian, random_tensor, xfail_if_not_implemented, zeros from funsor.util import get_backend diff --git a/test/test_memoize.py b/test/test_memoize.py index 79206d534..7fa7611d0 100644 --- a/test/test_memoize.py +++ b/test/test_memoize.py @@ -14,7 +14,6 @@ from funsor.testing import make_einsum_example, xfail_param from funsor.util import get_backend - EINSUM_EXAMPLES = [ ("a,b->", ''), ("ab,a->", ''), diff --git a/test/test_ops.py b/test/test_ops.py index 9db97165b..6a26a1621 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -6,8 +6,8 @@ import pytest from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND -from funsor.util import get_backend from funsor.ops import WrappedTransformOp +from funsor.util import get_backend @pytest.fixture diff --git a/test/test_sum_product.py b/test/test_sum_product.py index 7ea9854a1..f83e96abb 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -import re import os +import re from collections import OrderedDict from functools import partial, reduce @@ -15,12 +15,12 @@ from funsor.sum_product import ( MarkovProduct, _partition, - partial_unroll, mixed_sequential_sum_product, + modified_partial_sum_product, naive_sarkka_bilmes_product, naive_sequential_sum_product, partial_sum_product, - modified_partial_sum_product, + partial_unroll, sarkka_bilmes_product, sequential_sum_product, sum_product diff --git a/test/test_tensor.py b/test/test_tensor.py index 607ed9b7c..3ffcd04e6 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -13,7 +13,7 @@ import funsor import funsor.ops as ops -from funsor.domains import Array, Bint, Real, Product, Reals, find_domain +from funsor.domains import Array, Bint, Product, Real, Reals, find_domain from funsor.interpreter import interpretation from funsor.tensor import REDUCE_OP_TO_NUMERIC, Einsum, Tensor, align_tensors, numeric_array, stack, tensordot from funsor.terms import Cat, Lambda, Number, Slice, Stack, Variable, lazy From bac807c3345250240e9541fcd37967ffe0e4e152 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 2 Feb 2021 14:24:00 -0500 Subject: [PATCH 34/66] fix import --- funsor/torch/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index a4d3d62ba..f94fdad48 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -3,15 +3,16 @@ import torch -import funsor.ops as ops +from multipledispatch import dispatch + import funsor.torch.distributions # noqa: F401 import funsor.torch.ops # noqa: F401 +import funsor.ops as ops from funsor.adjoint import adjoint_ops from funsor.interpreter import children, recursion_reinterpret from funsor.tensor import Tensor, tensor_to_funsor from funsor.terms import Funsor, to_funsor from funsor.util import quote -from multipledispatch import dispatch @adjoint_ops.register( From 
f0073426cf871b9bb116eb6a98b446d773846c99 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 2 Feb 2021 21:59:13 -0500 Subject: [PATCH 35/66] lint --- funsor/registry.py | 3 ++- funsor/testing.py | 4 ++-- funsor/torch/__init__.py | 1 - funsor/typing.py | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index c5287d11c..0f4e8d5f7 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -3,9 +3,10 @@ from collections import defaultdict -from funsor.typing import Variadic, deep_type, get_origin, typing_wrap from multipledispatch.dispatcher import Dispatcher, expand_tuples +from funsor.typing import Variadic, deep_type, get_origin, typing_wrap + class PartialDispatcher(Dispatcher): """ diff --git a/funsor/testing.py b/funsor/testing.py index c5f05f7c5..ef5dc7ef9 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -11,6 +11,8 @@ import numpy as np import opt_einsum +from multipledispatch import dispatch +from multipledispatch.variadic import Variadic import funsor.ops as ops from funsor.cnf import Contraction @@ -20,8 +22,6 @@ from funsor.tensor import Tensor from funsor.terms import Funsor, Number from funsor.util import get_backend -from multipledispatch import dispatch -from multipledispatch.variadic import Variadic @contextlib.contextmanager diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index c3dc0f78c..c2c0e4831 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import torch - from multipledispatch import dispatch from funsor.adjoint import adjoint_ops diff --git a/funsor/typing.py b/funsor/typing.py index 697b4e6ca..bf98ddc67 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -6,7 +6,6 @@ import weakref import typing_extensions - from multipledispatch.variadic import Variadic as _OrigVariadic from multipledispatch.variadic import isvariadic From 1310bb54d23797c7d7d91912804ce5ec1198da31 Mon Sep 17 00:00:00 2001 
From: Eli Date: Tue, 2 Feb 2021 22:33:10 -0500 Subject: [PATCH 36/66] try to fix python 3.6 --- funsor/typing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index bf98ddc67..734a2c694 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import sys import typing import weakref @@ -104,22 +105,22 @@ def _type_to_typing(tp): def get_args(tp): - if isinstance(tp, GenericTypeMeta): + if isinstance(tp, GenericTypeMeta) or sys.version_info[:2] < (3, 7): return getattr(tp, "__args__", ()) result = typing_extensions.get_args(tp) return () if result is None else result def get_origin(tp): - if isinstance(tp, GenericTypeMeta): + if isinstance(tp, GenericTypeMeta) or sys.version_info[:2] < (3, 7): return getattr(tp, "__origin__", tp) result = typing_extensions.get_origin(tp) return tp if result is None else result -def get_type_hints(obj, globalns=None, localns=None, include_extras=False): +def get_type_hints(obj, globalns=None, localns=None, **kwargs): return typing_extensions.get_type_hints( - obj, globalns=globalns, localns=localns, include_extras=include_extras + obj, globalns=globalns, localns=localns, **kwargs ) From 646d8e5e8f7dabe56141edfc566980f740d44c83 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 2 Feb 2021 22:43:54 -0500 Subject: [PATCH 37/66] Add test stages with other python versions to travis --- .travis.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.travis.yml b/.travis.yml index c5878d10f..e0eb48371 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,3 +45,15 @@ jobs: - pip install https://github.com/pyro-ppl/numpyro/archive/master.zip - pip install -e .[jax] - CI=1 FUNSOR_BACKEND=jax make test + - name: numpy37 + python: 3.7 + script: + - make test + - name: numpy38 + python: 3.8 + script: + - make test + - name: numpy39 + python: 3.9 + script: + - make test From 
2d79c0a02fd421ec908ac94ee271447e99ec2296 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 2 Feb 2021 22:59:46 -0500 Subject: [PATCH 38/66] attempt to fix --- funsor/typing.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 734a2c694..8188b2daf 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -106,15 +106,17 @@ def _type_to_typing(tp): def get_args(tp): if isinstance(tp, GenericTypeMeta) or sys.version_info[:2] < (3, 7): - return getattr(tp, "__args__", ()) - result = typing_extensions.get_args(tp) + result = getattr(tp, "__args__", None) + else: + result = typing_extensions.get_args(tp) return () if result is None else result def get_origin(tp): if isinstance(tp, GenericTypeMeta) or sys.version_info[:2] < (3, 7): - return getattr(tp, "__origin__", tp) - result = typing_extensions.get_origin(tp) + result = getattr(tp, "__origin__", None) + else: + result = typing_extensions.get_origin(tp) return tp if result is None else result From 01016b96cf79e3d714b67650899d3e8844fcee7e Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Feb 2021 20:17:24 -0500 Subject: [PATCH 39/66] remove parametric subclass tests from test_terms --- test/test_terms.py | 101 --------------------------------------------- 1 file changed, 101 deletions(-) diff --git a/test/test_terms.py b/test/test_terms.py index d9447575e..d8c3e0183 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -29,11 +29,9 @@ from funsor.terms import ( Binary, Cat, - Funsor, Independent, Lambda, Number, - Reduce, Slice, Stack, Subs, @@ -576,105 +574,6 @@ def test_align_simple(): assert f(x=1, y=2, z=3) == g(x=1, y=2, z=3) -@pytest.mark.parametrize( - "subcls_expr,cls_expr", - [ - ("Reduce", "Reduce"), - ("Reduce[ops.AssociativeOp, Funsor, frozenset]", "Funsor"), - ("Reduce[ops.AssociativeOp, Funsor, frozenset]", "Reduce"), - ( - "Reduce[ops.AssociativeOp, Funsor, frozenset]", - "Reduce[ops.Op, Funsor, frozenset]", - ), - ( - 
"Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.Op, Funsor, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.AssociativeOp, Reduce, frozenset]", - ), - ("Stack[str, typing.Tuple[Number, Number, Number]]", "Stack"), - ("Stack[str, typing.Tuple[Number, Number, Number]]", "Stack[str, tuple]"), - # Unions - ( - "Reduce[ops.AssociativeOp, typing.Union[Number, Stack[str, typing.Tuple[Number, Number]]], frozenset]", - "Funsor", - ), - ( - "Reduce[ops.AssociativeOp, typing.Union[Number, Stack], frozenset]", - "Reduce[ops.Op, Funsor, frozenset]", - ), - ], -) -def test_parametric_subclass(subcls_expr, cls_expr): - subcls = eval(subcls_expr) - cls = eval(cls_expr) - print(subcls.classname) - print(cls.classname) - assert issubclass(cls, (Funsor, Reduce)) and not issubclass( - subcls, typing.Tuple - ) # appease flake8 - assert issubclass(subcls, cls) - - -@pytest.mark.parametrize( - "subcls_expr,cls_expr", - [ - ("Funsor", "Reduce[ops.AssociativeOp, Funsor, frozenset]"), - ("Reduce", "Reduce[ops.AssociativeOp, Funsor, frozenset]"), - ( - "Reduce[ops.Op, Funsor, frozenset]", - "Reduce[ops.AssociativeOp, Funsor, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.Op, Variable, frozenset]", - ), - ( - "Reduce[ops.AssociativeOp, Reduce[ops.AssociativeOp, Funsor, frozenset], frozenset]", - "Reduce[ops.AssociativeOp, Reduce[ops.AddOp, Funsor, frozenset], frozenset]", - ), - ("Stack", "Stack[str, typing.Tuple[Number, Number, Number]]"), - ("Stack[str, tuple]", "Stack[str, typing.Tuple[Number, Number, Number]]"), - ( - "Stack[str, typing.Tuple[Number, Number]]", - "Stack[str, typing.Tuple[Number, Reduce]]", - ), - ( - "Stack[str, typing.Tuple[Number, Reduce]]", - "Stack[str, typing.Tuple[Number, Number]]", - ), - # Unions - ( - "Funsor", - "Reduce[ops.AssociativeOp, typing.Union[Number, 
Funsor], frozenset]", - ), - ( - "Reduce[ops.Op, Funsor, frozenset]", - "Reduce[ops.AssociativeOp, typing.Union[Number, Stack], frozenset]", - ), - ( - "Reduce[typing.Union[ops.Op, ops.AssociativeOp], Stack, frozenset]", - "Reduce[ops.AssociativeOp, typing.Union[Stack[str, tuple], Reduce[ops.AssociativeOp, typing.Union[Cat, Stack], frozenset]], frozenset]", # noqa: E501 - ), - ( - "Reduce[ops.AssociativeOp, typing.Union[Stack, Reduce[ops.AssociativeOp, typing.Union[Number, Stack], frozenset]], frozenset]", # noqa: E501 - "Reduce[typing.Union[ops.Op, ops.AssociativeOp], Stack, frozenset]", - ), - ], -) -def test_not_parametric_subclass(subcls_expr, cls_expr): - subcls = eval(subcls_expr) - cls = eval(cls_expr) - print(subcls.classname) - print(cls.classname) - assert issubclass(cls, (Funsor, Reduce)) and not issubclass( - subcls, typing.Tuple - ) # appease flake8 - assert not issubclass(subcls, cls) - - @pytest.mark.parametrize( "start,stop", [(1, 3), (0, 1), (3, 7), (4, 8), (0, 2), (0, 10), (1, 2), (1, 10), (2, 10)], From ee1a611f055d64ebfb427abdc5de8997f307dff6 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Feb 2021 20:22:24 -0500 Subject: [PATCH 40/66] add new test_typing.py --- test/test_terms.py | 2 + test/test_typing.py | 122 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 test/test_typing.py diff --git a/test/test_terms.py b/test/test_terms.py index d8c3e0183..e495e3e32 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -32,6 +32,7 @@ Independent, Lambda, Number, + Reduce, Slice, Stack, Subs, @@ -43,6 +44,7 @@ from funsor.testing import assert_close, check_funsor, random_tensor assert Binary # flake8 +assert Reduce # flake8 assert Subs # flake8 assert Contraction # flake8 assert Reals # flake8 diff --git a/test/test_typing.py b/test/test_typing.py new file mode 100644 index 000000000..afc4c2f70 --- /dev/null +++ b/test/test_typing.py @@ -0,0 +1,122 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import typing + +from funsor.ops import AssociativeOp, Op +from funsor.terms import Cat, Funsor, Number, Reduce, Stack, Variable +from funsor.typing import deep_issubclass + + +def test_deep_issubclass_identity(): + assert deep_issubclass(Reduce, Reduce) + assert deep_issubclass( + Reduce[AssociativeOp, Funsor, frozenset], + Reduce[AssociativeOp, Funsor, frozenset], + ) + + +def test_deep_issubclass_empty(): + assert deep_issubclass(Reduce[AssociativeOp, Funsor, frozenset], Funsor) + assert deep_issubclass(Reduce[AssociativeOp, Funsor, frozenset], Reduce) + assert not deep_issubclass(Funsor, Reduce[AssociativeOp, Funsor, frozenset]) + assert not deep_issubclass(Reduce, Reduce[AssociativeOp, Funsor, frozenset]) + + +def test_deep_issubclass_neither(): + assert not deep_issubclass( + Reduce[AssociativeOp, Reduce[AssociativeOp, Funsor, frozenset], frozenset], + Reduce[Op, Variable, frozenset], + ) + assert not deep_issubclass( + Reduce[Op, Variable, frozenset], + Reduce[AssociativeOp, Reduce[AssociativeOp, Funsor, frozenset], frozenset], + ) + + assert not deep_issubclass( + Stack[str, typing.Tuple[Number, Number]], + Stack[str, typing.Tuple[Number, Reduce]], + ) + assert not deep_issubclass( + Stack[str, typing.Tuple[Number, Reduce]], + Stack[str, typing.Tuple[Number, Number]], + ) + + +def test_deep_issubclass_tuple_internal(): + assert deep_issubclass(Stack[str, typing.Tuple[Number, Number, Number]], Stack) + assert deep_issubclass( + Stack[str, typing.Tuple[Number, Number, Number]], Stack[str, tuple] + ) + assert not deep_issubclass(Stack, Stack[str, typing.Tuple[Number, Number, Number]]) + assert not deep_issubclass( + Stack[str, tuple], Stack[str, typing.Tuple[Number, Number, Number]] + ) + + +def test_deep_issubclass_tuple_finite(): + assert not deep_issubclass( + Stack[str, typing.Tuple[Number, Number]], + Stack[str, typing.Tuple[Number, Reduce]], + ) + + +def test_deep_issubclass_union_internal(): + + assert 
deep_issubclass( + Reduce[AssociativeOp, typing.Union[Number, Funsor], frozenset], Funsor + ) + assert not deep_issubclass( + Funsor, Reduce[AssociativeOp, typing.Union[Number, Funsor], frozenset] + ) + + assert deep_issubclass( + Reduce[ + AssociativeOp, + typing.Union[Number, Stack[str, typing.Tuple[Number, Number]]], + frozenset, + ], + Funsor, + ) + assert deep_issubclass( + Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + Reduce[Op, Funsor, frozenset], + ) + + assert deep_issubclass( + Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + Reduce[AssociativeOp, Funsor, frozenset], + ) + assert not deep_issubclass( + Reduce[AssociativeOp, Funsor, frozenset], + Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + ) + assert not deep_issubclass( + Reduce[Op, Funsor, frozenset], + Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + ) + + +def test_deep_issubclass_union_internal_multiple(): + assert not deep_issubclass( + Reduce[typing.Union[Op, AssociativeOp], Stack, frozenset], + Reduce[ + AssociativeOp, + typing.Union[ + Stack[str, tuple], + Reduce[AssociativeOp, typing.Union[Cat, Stack], frozenset], + ], + frozenset, + ], + ) + + assert not deep_issubclass( + Reduce[ + AssociativeOp, + typing.Union[ + Stack, Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset] + ], + frozenset, + ], + Reduce[typing.Union[Op, AssociativeOp], Stack, frozenset], + ) From cb5eea5841bf6513114495622ec6cf237be87244 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Feb 2021 20:31:43 -0500 Subject: [PATCH 41/66] get_origin --- funsor/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/registry.py b/funsor/registry.py index 0f4e8d5f7..707b8da4d 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -91,7 +91,7 @@ def decorator(fn): return decorator def __contains__(self, key): - return key in self.registry + return get_origin(key) in self.registry def __getitem__(self, key): key = 
get_origin(key) From 107aebec270d364cd565cf6e7298e953bc4dec51 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Feb 2021 20:36:30 -0500 Subject: [PATCH 42/66] more get_origin --- funsor/terms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 065a2aa4d..fa9c78288 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -28,7 +28,7 @@ from funsor.interpreter import PatternMissingError, interpret from funsor.ops import AssociativeOp, GetitemOp, Op from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS -from funsor.typing import GenericTypeMeta, Variadic, deep_type +from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_origin from funsor.util import getargspec, lazy_property, pretty, quote from . import instrument, interpreter, ops @@ -106,7 +106,7 @@ def reflect(cls, *args, **kwargs): return cls._cons_cache[cache_key] arg_types = tuple(map(deep_type, args)) - cls_specific = (cls.__origin__ if cls.__args__ else cls)[arg_types] + cls_specific = get_origin(cls)[arg_types] result = super(FunsorMeta, cls_specific).__call__(*args) result._ast_values = args @@ -114,7 +114,7 @@ def reflect(cls, *args, **kwargs): size, depth, width = _get_ast_stats(result) instrument.COUNTERS["ast_size"][size] += 1 instrument.COUNTERS["ast_depth"][depth] += 1 - classname = getattr(cls, "__origin__", cls).__name__ + classname = get_origin(cls).__name__ instrument.COUNTERS["funsor"][classname] += 1 instrument.COUNTERS[classname][width] += 1 From b6d2d46550f2af7f5a84fad113a506660aa64a44 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 17 Feb 2021 16:02:25 -0500 Subject: [PATCH 43/66] add more tests --- test/test_typing.py | 228 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 195 insertions(+), 33 deletions(-) diff --git a/test/test_typing.py b/test/test_typing.py index afc4c2f70..fed3e3956 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,11 +1,23 @@ # Copyright Contributors to the Pyro 
project. # SPDX-License-Identifier: Apache-2.0 -import typing +from typing import Any, FrozenSet, Optional, Tuple, Union + +from multipledispatch import dispatch from funsor.ops import AssociativeOp, Op +from funsor.registry import PartialDispatcher from funsor.terms import Cat, Funsor, Number, Reduce, Stack, Variable -from funsor.typing import deep_issubclass +from funsor.typing import ( + GenericTypeMeta, + Variadic, + deep_isinstance, + deep_issubclass, + deep_type, + get_args, + get_origin, + typing_wrap, +) def test_deep_issubclass_identity(): @@ -14,16 +26,17 @@ def test_deep_issubclass_identity(): Reduce[AssociativeOp, Funsor, frozenset], Reduce[AssociativeOp, Funsor, frozenset], ) + assert deep_issubclass(Tuple, Tuple) -def test_deep_issubclass_empty(): +def test_deep_issubclass_generic_empty(): assert deep_issubclass(Reduce[AssociativeOp, Funsor, frozenset], Funsor) assert deep_issubclass(Reduce[AssociativeOp, Funsor, frozenset], Reduce) assert not deep_issubclass(Funsor, Reduce[AssociativeOp, Funsor, frozenset]) assert not deep_issubclass(Reduce, Reduce[AssociativeOp, Funsor, frozenset]) -def test_deep_issubclass_neither(): +def test_deep_issubclass_generic_neither(): assert not deep_issubclass( Reduce[AssociativeOp, Reduce[AssociativeOp, Funsor, frozenset], frozenset], Reduce[Op, Variable, frozenset], @@ -34,77 +47,72 @@ def test_deep_issubclass_neither(): ) assert not deep_issubclass( - Stack[str, typing.Tuple[Number, Number]], - Stack[str, typing.Tuple[Number, Reduce]], + Stack[str, Tuple[Number, Number]], + Stack[str, Tuple[Number, Reduce]], ) assert not deep_issubclass( - Stack[str, typing.Tuple[Number, Reduce]], - Stack[str, typing.Tuple[Number, Number]], + Stack[str, Tuple[Number, Reduce]], + Stack[str, Tuple[Number, Number]], ) def test_deep_issubclass_tuple_internal(): - assert deep_issubclass(Stack[str, typing.Tuple[Number, Number, Number]], Stack) - assert deep_issubclass( - Stack[str, typing.Tuple[Number, Number, Number]], Stack[str, tuple] - 
) - assert not deep_issubclass(Stack, Stack[str, typing.Tuple[Number, Number, Number]]) + assert deep_issubclass(Stack[str, Tuple[Number, Number, Number]], Stack) + assert deep_issubclass(Stack[str, Tuple[Number, Number, Number]], Stack[str, tuple]) + assert not deep_issubclass(Stack, Stack[str, Tuple[Number, Number, Number]]) assert not deep_issubclass( - Stack[str, tuple], Stack[str, typing.Tuple[Number, Number, Number]] + Stack[str, tuple], Stack[str, Tuple[Number, Number, Number]] ) - - -def test_deep_issubclass_tuple_finite(): assert not deep_issubclass( - Stack[str, typing.Tuple[Number, Number]], - Stack[str, typing.Tuple[Number, Reduce]], + Stack[str, Tuple[Number, Number]], + Stack[str, Tuple[Number, Reduce]], ) def test_deep_issubclass_union_internal(): assert deep_issubclass( - Reduce[AssociativeOp, typing.Union[Number, Funsor], frozenset], Funsor + Reduce[AssociativeOp, Union[Number, Funsor], frozenset], Funsor ) assert not deep_issubclass( - Funsor, Reduce[AssociativeOp, typing.Union[Number, Funsor], frozenset] + Funsor, Reduce[AssociativeOp, Union[Number, Funsor], frozenset] ) assert deep_issubclass( Reduce[ AssociativeOp, - typing.Union[Number, Stack[str, typing.Tuple[Number, Number]]], + Union[Number, Stack[str, Tuple[Number, Number]]], frozenset, ], Funsor, ) assert deep_issubclass( - Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + Reduce[AssociativeOp, Union[Number, Stack], frozenset], Reduce[Op, Funsor, frozenset], ) assert deep_issubclass( - Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + Reduce[AssociativeOp, Union[Number, Stack], frozenset], Reduce[AssociativeOp, Funsor, frozenset], ) assert not deep_issubclass( Reduce[AssociativeOp, Funsor, frozenset], - Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + Reduce[AssociativeOp, Union[Number, Stack], frozenset], ) assert not deep_issubclass( Reduce[Op, Funsor, frozenset], - Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset], + 
Reduce[AssociativeOp, Union[Number, Stack], frozenset], ) def test_deep_issubclass_union_internal_multiple(): assert not deep_issubclass( - Reduce[typing.Union[Op, AssociativeOp], Stack, frozenset], + Reduce[Union[Op, AssociativeOp], Stack, frozenset], Reduce[ AssociativeOp, - typing.Union[ + Union[ Stack[str, tuple], - Reduce[AssociativeOp, typing.Union[Cat, Stack], frozenset], + Reduce[AssociativeOp, Union[Cat, Stack], frozenset], ], frozenset, ], @@ -113,10 +121,164 @@ def test_deep_issubclass_union_internal_multiple(): assert not deep_issubclass( Reduce[ AssociativeOp, - typing.Union[ - Stack, Reduce[AssociativeOp, typing.Union[Number, Stack], frozenset] - ], + Union[Stack, Reduce[AssociativeOp, Union[Number, Stack], frozenset]], frozenset, ], - Reduce[typing.Union[Op, AssociativeOp], Stack, frozenset], + Reduce[Union[Op, AssociativeOp], Stack, frozenset], + ) + + +def test_deep_issubclass_tuple_variadic(): + + assert deep_issubclass(Tuple[int, ...], Tuple) + assert deep_issubclass(Tuple[int], Tuple[int, ...]) + assert deep_issubclass(Tuple[int, int], Tuple[int, ...]) + + assert deep_issubclass(Tuple[Reduce, ...], Tuple[Funsor, ...]) + assert not deep_issubclass(Tuple[Funsor, ...], Tuple[Reduce, ...]) + + assert deep_issubclass(Tuple[Reduce], Tuple[Funsor, ...]) + assert not deep_issubclass(Tuple[Funsor], Tuple[Reduce, ...]) + + assert not deep_issubclass(Tuple[str], Tuple[int, ...]) + assert not deep_issubclass(Tuple[int, str], Tuple[int, ...]) + + assert deep_issubclass(Tuple[Tuple[int, str]], Tuple[Tuple, ...]) + assert deep_issubclass(Tuple[Tuple[int, str], Tuple[int, str]], Tuple[Tuple, ...]) + assert deep_issubclass(Tuple[Tuple[int, str]], Tuple[Tuple[int, str], ...]) + assert deep_issubclass(Tuple[Tuple[int, str], ...], Tuple[Tuple[int, str], ...]) + assert deep_issubclass( + Tuple[Tuple[int, str], Tuple[int, str]], Tuple[Tuple[int, str], ...] 
) + + +def test_generic_type_cons_hash(): + class A(metaclass=GenericTypeMeta): + pass + + class B(metaclass=GenericTypeMeta): + pass + + assert A[int] is A[int] + assert A[float] is not A[int] + assert B[int] is not A[int] + assert B[A[int], int] is B[A[int], int] + + assert FrozenSet[int] is FrozenSet[int] + assert FrozenSet[B[int]] is FrozenSet[B[int]] + + assert Tuple[B[int, int], ...] is Tuple[B[int, int], ...] + + assert Union[B[int]] is Union[B[int]] + assert Union[B[int], B[int]] is Union[B[int]] + + +def test_get_origin(): + + assert get_origin(Any) is Any + + assert get_origin(Tuple[int]) in (tuple, Tuple) + assert get_origin(FrozenSet[int]) in (frozenset, FrozenSet) + + assert get_origin(Union[int, int, str]) is Union + + assert get_origin(Reduce[AssociativeOp, Funsor, frozenset]) is Reduce + assert get_origin(Reduce) is Reduce + + +def test_get_args(): + assert not get_args(Any) + + assert get_args(Tuple[int]) == (int,) + assert get_args(Tuple[int, ...]) == (int, ...) + assert not get_args(Tuple) + + assert get_args(FrozenSet[int]) == (int,) + + assert int in get_args(Optional[int]) + + assert get_args(Union[int]) == () + assert get_args(Union[int, str]) == (int, str) + assert get_args(Union[int, int, str]) == (int, str) + + assert get_args(Reduce[AssociativeOp, Funsor, frozenset]) == ( + AssociativeOp, + Funsor, + frozenset, + ) + assert not get_args(Reduce) + + +def test_variadic_dispatch_basic(): + @dispatch + def f(*args): + return 1 + + @dispatch(int, int) + def f(*args): + return 2 + + @dispatch(Variadic[int]) + def f(*args): + return 3 + + @dispatch(typing_wrap(Tuple[int, int]), typing_wrap(Tuple[int, int])) + def f(*args): + return 4 + + @dispatch(Variadic[Tuple[int, int]]) + def f(*args): + return 5 + + assert f(1.5) == 1 + assert f(1.5, 1) == 1 + + assert f(1, 1) == 2 + assert f(1) == 3 + assert f(1, 2, 3) == 3 + + assert f((1, 1), (1, 1)) == 4 + assert f((1, 2)) == 5 + assert f((1, 2), (3, 4), (5, 6)) == 5 + + +def 
test_variadic_dispatch_typing(): + + f = PartialDispatcher("f", default=lambda *args: 0) + + @f.register + def _(a: Any) -> int: + return 0 + + @f.register + def _(a: int, b: int) -> int: + return 2 + + @f.register(Variadic[int]) + def _(*args): + return 3 + + @f.register + def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: + return 4 + + @f.register(Variadic[Tuple[int, int]]) + def _(*args): + return 5 + + @f.register + def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: + return 6 + + assert f(1.5) == 1 + assert f(1.5, 1) == 0 + + assert f(1, 1) == 2 + assert f(1) == 3 + assert f(1, 2, 3) == 3 + + assert f((1, 1), (1, 1)) == 4 + assert f((1, 2)) == 5 + assert f((1, 2), (3, 4), (5, 6)) == 5 + + assert f((1, 1.5), (2, 2.5)) == 6 From b773aa4fae5bb9adfbd27c3a8d7496b082e5af39 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 17 Feb 2021 16:33:14 -0500 Subject: [PATCH 44/66] fixes for variadic dispatch --- funsor/registry.py | 4 +++- funsor/typing.py | 11 +++++++++-- test/test_typing.py | 22 +++++++++------------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index 707b8da4d..3cba9d016 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -34,7 +34,9 @@ def add(self, signature, func): return # Handle variadic types - signature = (Variadic[tp] if isinstance(tp, list) else tp for tp in signature) + signature = ( + Variadic[tuple(tp)] if isinstance(tp, list) else tp for tp in signature + ) signature = tuple(map(typing_wrap, signature)) super().add(signature, func) diff --git a/funsor/typing.py b/funsor/typing.py index 8188b2daf..42d114568 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -38,6 +38,15 @@ def deep_issubclass(subcls, cls): """replaces issubclass()""" # return pytypes.is_subtype(subcls, cls) + # handle unpacking + if isinstance(subcls, _RuntimeSubclassCheckMeta): + try: + return deep_issubclass(subcls.__args__[0], cls) + except TypeError as e: + if e.args[0] == "issubclass() arg 1 
must be a class": + return False + raise + if get_origin(cls) is typing.Union: return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) @@ -202,8 +211,6 @@ def __call__(cls, tp): return tp if isinstance(tp, GenericTypeMeta) or isvariadic(tp) else cls[tp] def __subclasscheck__(cls, subcls): - if isinstance(subcls, _RuntimeSubclassCheckMeta): - subcls = subcls.__args__[0] return deep_issubclass(subcls, cls.__args__[0]) diff --git a/test/test_typing.py b/test/test_typing.py index fed3e3956..a5053a200 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -210,23 +210,23 @@ def test_get_args(): def test_variadic_dispatch_basic(): - @dispatch + @dispatch(Variadic[object]) def f(*args): return 1 @dispatch(int, int) - def f(*args): + def f(a, b): return 2 @dispatch(Variadic[int]) def f(*args): return 3 - @dispatch(typing_wrap(Tuple[int, int]), typing_wrap(Tuple[int, int])) - def f(*args): + @dispatch(typing_wrap(Tuple), typing_wrap(Tuple)) + def f(a, b): return 4 - @dispatch(Variadic[Tuple[int, int]]) + @dispatch(Variadic[Tuple]) def f(*args): return 5 @@ -244,17 +244,13 @@ def f(*args): def test_variadic_dispatch_typing(): - f = PartialDispatcher("f", default=lambda *args: 0) - - @f.register - def _(a: Any) -> int: - return 0 + f = PartialDispatcher("f", default=lambda *args: 1) @f.register def _(a: int, b: int) -> int: return 2 - @f.register(Variadic[int]) + @f.register([int]) def _(*args): return 3 @@ -262,7 +258,7 @@ def _(*args): def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: return 4 - @f.register(Variadic[Tuple[int, int]]) + @f.register([Tuple[int, int]]) # list syntax for variadic def _(*args): return 5 @@ -271,7 +267,7 @@ def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: return 6 assert f(1.5) == 1 - assert f(1.5, 1) == 0 + assert f(1.5, 1) == 1 assert f(1, 1) == 2 assert f(1) == 3 From 37d7022249811e6fcd704e81a5aaff65635163da Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 17 Feb 2021 16:41:36 -0500 Subject: [PATCH 45/66] add another 
dispatch test with no variadic patterns --- test/test_typing.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/test_typing.py b/test/test_typing.py index a5053a200..fe636a9e9 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -242,6 +242,38 @@ def f(*args): assert f((1, 2), (3, 4), (5, 6)) == 5 +def test_dispatch_typing(): + + f = PartialDispatcher("f", default=lambda *args: 1) + + @f.register + def _(a: int, b: int) -> int: + return 2 + + @f.register + def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: + return 3 + + @f.register + def _(a: Tuple[int, ...], b: Tuple[int, int]) -> int: + return 4 + + @f.register + def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: + return 5 + + assert f(1.5) == 1 + assert f(1.5, 1) == 1 + + assert f(1, 1) == 2 + + assert f((1, 1), (1, 1)) == 3 + assert f((1, 2)) == 0 + assert f((1, 2, 3), (4, 5)) == 4 + + assert f((1, 1.5), (2, 2.5)) == 5 + + def test_variadic_dispatch_typing(): f = PartialDispatcher("f", default=lambda *args: 1) From 9a73f561bdf0db250f2f16a648f0756b508a4859 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 09:32:30 -0500 Subject: [PATCH 46/66] add python 3.7,8,9 stages to github actions --- .github/workflows/ci.yml | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c855bf38..870e899fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -87,3 +87,81 @@ jobs: - name: Run test run: | CI=1 FUNSOR_BACKEND=jax make test + + + numpy37: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt install -y pandoc + python -m pip install --upgrade pip + # Keep track of pyro-api master branch + 
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip + pip install .[test] + pip freeze + - name: Run test + run: | + make test + + + numpy38: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt install -y pandoc + python -m pip install --upgrade pip + # Keep track of pyro-api master branch + pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip + pip install .[test] + pip freeze + - name: Run test + run: | + make test + + + numpy39: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt install -y pandoc + python -m pip install --upgrade pip + # Keep track of pyro-api master branch + pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip + pip install .[test] + pip freeze + - name: Run test + run: | + make test From d8e048713c638cc330263afb69e381e4c1cc265d Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 11:20:59 -0500 Subject: [PATCH 47/66] split up deep_issubclass --- funsor/typing.py | 143 ++++++++++++++++++++++++++------------------ test/test_typing.py | 18 ++++++ 2 files changed, 103 insertions(+), 58 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 42d114568..cc7002616 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import collections import functools import sys import typing @@ -15,22 +16,88 @@ ################################# -def deep_isinstance(obj, cls): - """replaces isinstance()""" - # return pytypes.is_of_type(obj, cls) - return deep_issubclass(deep_type(obj), cls) +@functools.singledispatch +def deep_type(obj): + return type(obj) -def deep_type(obj): - """replaces type()""" - # return pytypes.deep_type(obj) - if isinstance(obj, tuple): - return typing.Tuple[tuple(map(deep_type, obj))] if obj else typing.Tuple +@deep_type.register(tuple) +def _deep_type_tuple(obj): + return typing.Tuple[tuple(map(deep_type, obj))] if obj else typing.Tuple - if isinstance(obj, frozenset): - return typing.FrozenSet[next(map(deep_type, obj))] if obj else typing.FrozenSet - return type(obj) +@deep_type.register(frozenset) +def _deep_type_frozenset(obj): + return typing.FrozenSet[next(map(deep_type, obj))] if obj else typing.FrozenSet + + +_subclasscheck_registry = collections.defaultdict(lambda: issubclass) + + +def register_issubclass(cls): + def _fn(fn): + _subclasscheck_registry[cls] = fn + return fn + + return _fn + + +register_issubclass(typing.Any)(lambda a, b: True) + + +@register_issubclass(typing.Union) +def _deep_issubclass_union(subcls, cls): + return any(deep_issubclass(arg, cls) for arg in get_args(subcls)) + + +@register_issubclass(frozenset) +@register_issubclass(typing.FrozenSet) +def _deep_issubclass_frozenset(subcls, cls): + + if not issubclass(get_origin(subcls), frozenset): + return False + + cls_args, subcls_args = get_args(cls), get_args(subcls) + + if not cls_args: + return True + + if not subcls_args: + return cls_args[0] is typing.Any + + return len(subcls_args) == len(cls_args) == 1 and all( + deep_issubclass(a, b) for a, b in zip(subcls_args, cls_args) + ) + + +@register_issubclass(tuple) +@register_issubclass(typing.Tuple) +def _deep_issubclass_tuple(subcls, cls): + + if not issubclass(get_origin(subcls), get_origin(cls)): 
+ return False + + cls_args, subcls_args = get_args(cls), get_args(subcls) + + if not cls_args: # cls is base Tuple + return True + + if not subcls_args: + return cls_args[0] is typing.Any + + if cls_args[-1] is Ellipsis: # cls variadic + if subcls_args[-1] is Ellipsis: # both variadic + return deep_issubclass(subcls_args[0], cls_args[0]) + return all(deep_issubclass(a, cls_args[0]) for a in subcls_args) + + if subcls_args[-1] is Ellipsis: # only subcls variadic + # issubclass(Tuple[A, ...], Tuple[X, Y]) == False + return False + + # neither variadic + return len(cls_args) == len(subcls_args) and all( + deep_issubclass(a, b) for a, b in zip(subcls_args, cls_args) + ) @functools.lru_cache(maxsize=None) @@ -47,59 +114,19 @@ def deep_issubclass(subcls, cls): return False raise - if get_origin(cls) is typing.Union: - return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) - if get_origin(subcls) is typing.Union: return all(deep_issubclass(arg, cls) for arg in get_args(subcls)) - if cls is typing.Any: - return True - if subcls is typing.Any: - return False - - if issubclass(get_origin(cls), typing.FrozenSet): - - if not issubclass(get_origin(subcls), get_origin(cls)): - return False - - if not get_args(cls): - return True + return cls is typing.Any - if not get_args(subcls): - return get_args(cls)[0] is typing.Any + return _subclasscheck_registry[get_origin(cls)](subcls, cls) - return len(get_args(subcls)) == len(get_args(cls)) == 1 and deep_issubclass( - get_args(subcls)[0], get_args(cls)[0] - ) - - if issubclass(get_origin(cls), typing.Tuple): - - if not issubclass(get_origin(subcls), get_origin(cls)): - return False - - if not get_args(cls): # cls is base Tuple - return True - - if not get_args(subcls): - return get_args(cls)[0] is typing.Any - - if get_args(cls)[-1] is Ellipsis: # cls variadic - if get_args(subcls)[-1] is Ellipsis: # both variadic - return deep_issubclass(get_args(subcls)[0], get_args(cls)[0]) - return all(deep_issubclass(a, 
get_args(cls)[0]) for a in get_args(subcls)) - if get_args(subcls)[-1] is Ellipsis: # only subcls variadic - # issubclass(Tuple[A, ...], Tuple[X, Y]) == False - return False - - # neither variadic - return len(get_args(cls)) == len(get_args(subcls)) and all( - deep_issubclass(a, b) for a, b in zip(get_args(subcls), get_args(cls)) - ) - - return issubclass(subcls, cls) +def deep_isinstance(obj, cls): + """replaces isinstance()""" + # return pytypes.is_of_type(obj, cls) + return deep_issubclass(deep_type(obj), cls) def _type_to_typing(tp): diff --git a/test/test_typing.py b/test/test_typing.py index fe636a9e9..3643b82fb 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -310,3 +310,21 @@ def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: assert f((1, 2), (3, 4), (5, 6)) == 5 assert f((1, 1.5), (2, 2.5)) == 6 + + +def test_deep_type_simple(): + + x1 = (1, 1.5, "a") + expected_type1 = Tuple[int, float, str] + assert deep_type(x1) is expected_type1 + assert deep_isinstance(x1, expected_type1) + + x2 = frozenset(["a", "b"]) + expected_type2 = FrozenSet[str] + assert deep_type(x2) is expected_type2 + assert deep_isinstance(x2, expected_type2) + + x3 = (1, (2, 3)) + expected_type3 = Tuple[int, Tuple[int, int]] + assert deep_type(x3) is expected_type3 + assert deep_isinstance(x3, expected_type3) From 90566d149150cf7e71400f1dcb2cec7b2aed85ee Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 11:29:44 -0500 Subject: [PATCH 48/66] dont use defaultdict --- funsor/typing.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index cc7002616..7df260dd7 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -1,7 +1,6 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -import collections import functools import sys import typing @@ -31,10 +30,15 @@ def _deep_type_frozenset(obj): return typing.FrozenSet[next(map(deep_type, obj))] if obj else typing.FrozenSet -_subclasscheck_registry = collections.defaultdict(lambda: issubclass) +_subclasscheck_registry = {} def register_issubclass(cls): + """ + Decorator for registering a custom ``__subclasscheck__`` method for ``cls``, + for use in pattern matching with :func:`deep_issubclass`. + """ + def _fn(fn): _subclasscheck_registry[cls] = fn return fn @@ -46,13 +50,13 @@ def _fn(fn): @register_issubclass(typing.Union) -def _deep_issubclass_union(subcls, cls): +def _subclasscheck_union(cls, subcls): return any(deep_issubclass(arg, cls) for arg in get_args(subcls)) @register_issubclass(frozenset) @register_issubclass(typing.FrozenSet) -def _deep_issubclass_frozenset(subcls, cls): +def _subclasscheck_frozenset(cls, subcls): if not issubclass(get_origin(subcls), frozenset): return False @@ -72,7 +76,7 @@ def _deep_issubclass_frozenset(subcls, cls): @register_issubclass(tuple) @register_issubclass(typing.Tuple) -def _deep_issubclass_tuple(subcls, cls): +def _subclasscheck_tuple(cls, subcls): if not issubclass(get_origin(subcls), get_origin(cls)): return False @@ -120,7 +124,10 @@ def deep_issubclass(subcls, cls): if subcls is typing.Any: return cls is typing.Any - return _subclasscheck_registry[get_origin(cls)](subcls, cls) + try: + return _subclasscheck_registry[get_origin(cls)](subcls, cls) + except KeyError: + return issubclass(subcls, cls) def deep_isinstance(obj, cls): From 277d34c627f49a3807ac55e6d61ed7cfe9a182ed Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 11:30:31 -0500 Subject: [PATCH 49/66] typo --- funsor/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/typing.py b/funsor/typing.py index 7df260dd7..5016d4e65 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -125,7 +125,7 @@ def 
deep_issubclass(subcls, cls): return cls is typing.Any try: - return _subclasscheck_registry[get_origin(cls)](subcls, cls) + return _subclasscheck_registry[get_origin(cls)](cls, subcls) except KeyError: return issubclass(subcls, cls) From d54703c675f67c1a9c541360c755d8c23345c933 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 11:48:17 -0500 Subject: [PATCH 50/66] register_subclasscheck --- funsor/typing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 5016d4e65..1d0f2fd37 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -33,7 +33,7 @@ def _deep_type_frozenset(obj): _subclasscheck_registry = {} -def register_issubclass(cls): +def register_subclasscheck(cls): """ Decorator for registering a custom ``__subclasscheck__`` method for ``cls``, for use in pattern matching with :func:`deep_issubclass`. @@ -46,16 +46,16 @@ def _fn(fn): return _fn -register_issubclass(typing.Any)(lambda a, b: True) +register_subclasscheck(typing.Any)(lambda a, b: True) -@register_issubclass(typing.Union) +@register_subclasscheck(typing.Union) def _subclasscheck_union(cls, subcls): return any(deep_issubclass(arg, cls) for arg in get_args(subcls)) -@register_issubclass(frozenset) -@register_issubclass(typing.FrozenSet) +@register_subclasscheck(frozenset) +@register_subclasscheck(typing.FrozenSet) def _subclasscheck_frozenset(cls, subcls): if not issubclass(get_origin(subcls), frozenset): @@ -74,8 +74,8 @@ def _subclasscheck_frozenset(cls, subcls): ) -@register_issubclass(tuple) -@register_issubclass(typing.Tuple) +@register_subclasscheck(tuple) +@register_subclasscheck(typing.Tuple) def _subclasscheck_tuple(cls, subcls): if not issubclass(get_origin(subcls), get_origin(cls)): From a5d64d4d18529eb4caf3912d2e481c9967ea9c6f Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 12:23:57 -0500 Subject: [PATCH 51/66] fix tests --- funsor/registry.py | 2 +- funsor/typing.py | 9 ++++-- test/test_typing.py 
| 71 ++++++++++++++++++++++++++------------------- 3 files changed, 49 insertions(+), 33 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index 3cba9d016..2f84e4c00 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -17,7 +17,7 @@ def __init__(self, name, default=None): self.default = default if default is None else PartialDefault(default) super().__init__(name) if default is not None: - self.add((Variadic[object],), self.default) + self.add(([object],), self.default) def add(self, signature, func): diff --git a/funsor/typing.py b/funsor/typing.py index 1d0f2fd37..2b4e28d49 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -27,7 +27,12 @@ def _deep_type_tuple(obj): @deep_type.register(frozenset) def _deep_type_frozenset(obj): - return typing.FrozenSet[next(map(deep_type, obj))] if obj else typing.FrozenSet + if not obj: + return typing.FrozenSet + tp = deep_type(next(iter(obj))) + if not all(deep_isinstance(x, tp) for x in obj): + raise NotImplementedError(f"TODO handle inhomogeneous frozensets: {str(obj)}") + return typing.FrozenSet[tp] _subclasscheck_registry = {} @@ -115,7 +120,7 @@ def deep_issubclass(subcls, cls): return deep_issubclass(subcls.__args__[0], cls) except TypeError as e: if e.args[0] == "issubclass() arg 1 must be a class": - return False + return deep_issubclass(get_origin(subcls.__args__[0]), cls) raise if get_origin(subcls) is typing.Union: diff --git a/test/test_typing.py b/test/test_typing.py index 3643b82fb..a54b52130 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -3,6 +3,7 @@ from typing import Any, FrozenSet, Optional, Tuple, Union +import pytest from multipledispatch import dispatch from funsor.ops import AssociativeOp, Op @@ -134,6 +135,9 @@ def test_deep_issubclass_tuple_variadic(): assert deep_issubclass(Tuple[int], Tuple[int, ...]) assert deep_issubclass(Tuple[int, int], Tuple[int, ...]) + assert not deep_issubclass(Tuple[int, ...], Tuple[int]) + assert not deep_issubclass(Tuple[int, 
...], Tuple[int, int]) + assert deep_issubclass(Tuple[Reduce, ...], Tuple[Funsor, ...]) assert not deep_issubclass(Tuple[Funsor, ...], Tuple[Reduce, ...]) @@ -152,6 +156,31 @@ def test_deep_issubclass_tuple_variadic(): ) +def test_deep_type_tuple(): + + x1 = (1, 1.5, "a") + expected_type1 = Tuple[int, float, str] + assert deep_type(x1) is expected_type1 + assert deep_isinstance(x1, expected_type1) + + x2 = (1, (2, 3)) + expected_type2 = Tuple[int, Tuple[int, int]] + assert deep_type(x2) is expected_type2 + assert deep_isinstance(x2, expected_type2) + + +def test_deep_type_frozenset(): + + x1 = frozenset(["a", "b"]) + expected_type1 = FrozenSet[str] + assert deep_type(x1) is expected_type1 + assert deep_isinstance(x1, expected_type1) + + with pytest.raises(NotImplementedError): + x2 = frozenset(["a", 1]) + deep_type(x2) + + def test_generic_type_cons_hash(): class A(metaclass=GenericTypeMeta): pass @@ -246,20 +275,20 @@ def test_dispatch_typing(): f = PartialDispatcher("f", default=lambda *args: 1) - @f.register - def _(a: int, b: int) -> int: + @f.register() + def f2(a: int, b: int) -> int: return 2 - @f.register - def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: + @f.register() + def f3(a: Tuple[int, int], b: Tuple[int, int]) -> int: return 3 - @f.register - def _(a: Tuple[int, ...], b: Tuple[int, int]) -> int: + @f.register() + def f4(a: Tuple[int, ...], b: Tuple[int, int]) -> int: return 4 - @f.register - def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: + @f.register() + def f5(a: Tuple[int, float], b: Tuple[int, float]) -> int: return 5 assert f(1.5) == 1 @@ -268,7 +297,7 @@ def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: assert f(1, 1) == 2 assert f((1, 1), (1, 1)) == 3 - assert f((1, 2)) == 0 + assert f((1, 2)) == 1 assert f((1, 2, 3), (4, 5)) == 4 assert f((1, 1.5), (2, 2.5)) == 5 @@ -278,7 +307,7 @@ def test_variadic_dispatch_typing(): f = PartialDispatcher("f", default=lambda *args: 1) - @f.register + @f.register() def _(a: int, b: 
int) -> int: return 2 @@ -286,7 +315,7 @@ def _(a: int, b: int) -> int: def _(*args): return 3 - @f.register + @f.register() def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: return 4 @@ -294,7 +323,7 @@ def _(a: Tuple[int, int], b: Tuple[int, int]) -> int: def _(*args): return 5 - @f.register + @f.register() def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: return 6 @@ -310,21 +339,3 @@ def _(a: Tuple[int, float], b: Tuple[int, float]) -> int: assert f((1, 2), (3, 4), (5, 6)) == 5 assert f((1, 1.5), (2, 2.5)) == 6 - - -def test_deep_type_simple(): - - x1 = (1, 1.5, "a") - expected_type1 = Tuple[int, float, str] - assert deep_type(x1) is expected_type1 - assert deep_isinstance(x1, expected_type1) - - x2 = frozenset(["a", "b"]) - expected_type2 = FrozenSet[str] - assert deep_type(x2) is expected_type2 - assert deep_isinstance(x2, expected_type2) - - x3 = (1, (2, 3)) - expected_type3 = Tuple[int, Tuple[int, int]] - assert deep_type(x3) is expected_type3 - assert deep_isinstance(x3, expected_type3) From c62e991f0ad2e7c2583c812b4d0bff5be66971c9 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 12:37:11 -0500 Subject: [PATCH 52/66] add test for get_type_hints --- funsor/registry.py | 7 ++++--- test/test_typing.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index 2f84e4c00..07f81f116 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -5,7 +5,7 @@ from multipledispatch.dispatcher import Dispatcher, expand_tuples -from funsor.typing import Variadic, deep_type, get_origin, typing_wrap +from funsor.typing import Variadic, deep_type, get_origin, get_type_hints, typing_wrap class PartialDispatcher(Dispatcher): @@ -23,9 +23,10 @@ def add(self, signature, func): # Handle annotations if not signature: - annotations = self.get_func_annotations(func) + annotations = get_type_hints(func) + annotations.pop("return") if annotations: - signature = annotations + signature = 
tuple(annotations.values()) # Handle some union types by expanding at registration time if any(isinstance(typ, tuple) for typ in signature): diff --git a/test/test_typing.py b/test/test_typing.py index a54b52130..4d04e190e 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -17,6 +17,7 @@ deep_type, get_args, get_origin, + get_type_hints, typing_wrap, ) @@ -238,6 +239,21 @@ def test_get_args(): assert not get_args(Reduce) +def test_get_type_hints(): + def f(a: Tuple[int, ...], b: Reduce[AssociativeOp, Funsor, frozenset]) -> int: + return 0 + + hints = get_type_hints(f) + assert hints == { + "a": Tuple[int, ...], + "b": Reduce[AssociativeOp, Funsor, frozenset], + "return": int, + } + + hints.pop("return") + assert "return" in get_type_hints(f) + + def test_variadic_dispatch_basic(): @dispatch(Variadic[object]) def f(*args): From 1b9d7fff5f7c3fb8041daa8c28eeaa312201d4de Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 12:38:47 -0500 Subject: [PATCH 53/66] handle no return hint --- funsor/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/registry.py b/funsor/registry.py index 07f81f116..f741c38b4 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -24,7 +24,7 @@ def add(self, signature, func): # Handle annotations if not signature: annotations = get_type_hints(func) - annotations.pop("return") + annotations.pop("return", None) if annotations: signature = tuple(annotations.values()) From 30be93adc4292946e9bdb11122287f3fc439a87c Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 13:05:01 -0500 Subject: [PATCH 54/66] rename test --- test/test_typing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_typing.py b/test/test_typing.py index 4d04e190e..90e916820 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -22,7 +22,7 @@ ) -def test_deep_issubclass_identity(): +def test_deep_issubclass_generic_identity(): assert deep_issubclass(Reduce, Reduce) assert 
deep_issubclass( Reduce[AssociativeOp, Funsor, frozenset], @@ -58,7 +58,7 @@ def test_deep_issubclass_generic_neither(): ) -def test_deep_issubclass_tuple_internal(): +def test_deep_issubclass_generic_tuple_internal(): assert deep_issubclass(Stack[str, Tuple[Number, Number, Number]], Stack) assert deep_issubclass(Stack[str, Tuple[Number, Number, Number]], Stack[str, tuple]) assert not deep_issubclass(Stack, Stack[str, Tuple[Number, Number, Number]]) @@ -71,7 +71,7 @@ def test_deep_issubclass_tuple_internal(): ) -def test_deep_issubclass_union_internal(): +def test_deep_issubclass_generic_union_internal(): assert deep_issubclass( Reduce[AssociativeOp, Union[Number, Funsor], frozenset], Funsor @@ -107,7 +107,7 @@ def test_deep_issubclass_union_internal(): ) -def test_deep_issubclass_union_internal_multiple(): +def test_deep_issubclass_generic_union_internal_multiple(): assert not deep_issubclass( Reduce[Union[Op, AssociativeOp], Stack, frozenset], Reduce[ From e4c6a1f7d1eeddafc5dc465a92937a0deabcac79 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 13:08:24 -0500 Subject: [PATCH 55/66] attempt to use matrix syntax in github actions --- .github/workflows/ci.yml | 80 +--------------------------------------- 1 file changed, 1 insertion(+), 79 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 870e899fa..9b7e6009d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 @@ -87,81 +87,3 @@ jobs: - name: Run test run: | CI=1 FUNSOR_BACKEND=jax make test - - - numpy37: - - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - 
sudo apt install -y pandoc - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install .[test] - pip freeze - - name: Run test - run: | - make test - - - numpy38: - - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.8] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - sudo apt install -y pandoc - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install .[test] - pip freeze - - name: Run test - run: | - make test - - - numpy39: - - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.9] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - sudo apt install -y pandoc - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install .[test] - pip freeze - - name: Run test - run: | - make test From adc1ac0259522afc1f88077c06cf10d17799b038 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 13:33:38 -0500 Subject: [PATCH 56/66] fix union --- funsor/typing.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 2b4e28d49..9d65fcebb 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -30,8 +30,13 @@ def _deep_type_frozenset(obj): if not obj: return typing.FrozenSet tp = deep_type(next(iter(obj))) - if not all(deep_isinstance(x, tp) for x in obj): - raise NotImplementedError(f"TODO handle inhomogeneous frozensets: {str(obj)}") + for x 
in obj: + if not deep_isinstance(x, tp): + tp = get_origin(tp) + if not deep_isinstance(x, tp): + raise NotImplementedError( + f"TODO handle inhomogeneous frozensets: {str(obj)}" + ) return typing.FrozenSet[tp] @@ -56,7 +61,7 @@ def _fn(fn): @register_subclasscheck(typing.Union) def _subclasscheck_union(cls, subcls): - return any(deep_issubclass(arg, cls) for arg in get_args(subcls)) + return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) @register_subclasscheck(frozenset) From b1e812521dc7dc95294e2fcd8548a1dfbde13bfd Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 13:43:01 -0500 Subject: [PATCH 57/66] typing.get_type_hints --- funsor/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/typing.py b/funsor/typing.py index 9d65fcebb..1fc56d45e 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -174,7 +174,7 @@ def get_origin(tp): def get_type_hints(obj, globalns=None, localns=None, **kwargs): - return typing_extensions.get_type_hints( + return typing.get_type_hints( obj, globalns=globalns, localns=localns, **kwargs ) From e6710e184d4dc80ec81045bf08b87e758bd492b0 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 13:53:38 -0500 Subject: [PATCH 58/66] small optimizations --- funsor/typing.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 1fc56d45e..190ca0fc0 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -174,9 +174,7 @@ def get_origin(tp): def get_type_hints(obj, globalns=None, localns=None, **kwargs): - return typing.get_type_hints( - obj, globalns=globalns, localns=localns, **kwargs - ) + return typing.get_type_hints(obj, globalns=globalns, localns=localns, **kwargs) ###################################################################### @@ -202,38 +200,41 @@ def __init__(cls, name, bases, dct): def __getitem__(cls, arg_types): if not isinstance(arg_types, tuple): arg_types = (arg_types,) - assert 
not any( - isvariadic(arg_type) for arg_type in arg_types - ), "nested variadic types not supported" arg_types = tuple(map(_type_to_typing, arg_types)) - if arg_types not in cls._type_cache: + try: + return cls._type_cache[arg_types] + except KeyError: assert not get_args(cls), "cannot subscript a subscripted type {}".format( cls ) + assert not any( + isvariadic(arg_type) for arg_type in arg_types + ), "nested variadic types not supported" new_dct = cls.__dict__.copy() new_dct.update({"__args__": arg_types}) # type(cls) to handle GenericTypeMeta subclasses - cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) - return cls._type_cache[arg_types] + result = type(cls)(cls.__name__, (cls,), new_dct) + cls._type_cache[arg_types] = result + return result def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) if cls is subcls: return True + cls_origin = get_origin(cls) if not isinstance(subcls, GenericTypeMeta): - return super(GenericTypeMeta, get_origin(cls)).__subclasscheck__(subcls) + return super(GenericTypeMeta, cls_origin).__subclasscheck__(subcls) - if not super(GenericTypeMeta, get_origin(cls)).__subclasscheck__( - get_origin(subcls) - ): + if not super(GenericTypeMeta, cls_origin).__subclasscheck__(get_origin(subcls)): return False - if len(get_args(cls)) != len(get_args(subcls)): - return len(get_args(cls)) == 0 + cls_args, subcls_args = get_args(cls), get_args(subcls) + if len(cls_args) != len(subcls_args): + return len(cls_args) == 0 return all( deep_issubclass(_type_to_typing(ps), _type_to_typing(pc)) - for ps, pc in zip(get_args(subcls), get_args(cls)) + for ps, pc in zip(subcls_args, cls_args) ) def __repr__(cls): From 63d2c72f91845147ac755a7ef3861db53f02984e Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 15:38:16 -0500 Subject: [PATCH 59/66] add some documentation --- funsor/typing.py | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/funsor/typing.py 
b/funsor/typing.py index 190ca0fc0..d8c743bfb 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -17,6 +17,10 @@ @functools.singledispatch def deep_type(obj): + """ + An enhanced version of :func:`type`. + """ + # compare to pytypes.deep_type(obj) return type(obj) @@ -45,8 +49,12 @@ def _deep_type_frozenset(obj): def register_subclasscheck(cls): """ - Decorator for registering a custom ``__subclasscheck__`` method for ``cls``, - for use in pattern matching with :func:`deep_issubclass`. + Decorator for registering a custom ``__subclasscheck__`` method for ``cls`` + which is only ever invoked in :func:`deep_issubclass`. + + This is primarily intended for working with the `typing` library at runtime. + Prefer overriding ``__subclasscheck__`` in the usual way with a metaclass + where possible. """ def _fn(fn): @@ -56,17 +64,21 @@ def _fn(fn): return _fn -register_subclasscheck(typing.Any)(lambda a, b: True) +@register_subclasscheck(typing.Any) +def _subclasscheck_any(cls, subcls): + return True @register_subclasscheck(typing.Union) def _subclasscheck_union(cls, subcls): + """A basic ``__subclasscheck__`` method for :class:`typing.Union`.""" return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) @register_subclasscheck(frozenset) @register_subclasscheck(typing.FrozenSet) def _subclasscheck_frozenset(cls, subcls): + """A basic ``__subclasscheck__`` method for :class:`typing.FrozenSet`.""" if not issubclass(get_origin(subcls), frozenset): return False @@ -87,6 +99,7 @@ def _subclasscheck_frozenset(cls, subcls): @register_subclasscheck(tuple) @register_subclasscheck(typing.Tuple) def _subclasscheck_tuple(cls, subcls): + """A basic ``__subclasscheck__`` method for :class:`typing.Tuple`.""" if not issubclass(get_origin(subcls), get_origin(cls)): return False @@ -117,7 +130,7 @@ def _subclasscheck_tuple(cls, subcls): @functools.lru_cache(maxsize=None) def deep_issubclass(subcls, cls): """replaces issubclass()""" - # return pytypes.is_subtype(subcls, cls) + # 
compare to pytypes.is_subtype(subcls, cls) # handle unpacking if isinstance(subcls, _RuntimeSubclassCheckMeta): @@ -142,7 +155,7 @@ def deep_issubclass(subcls, cls): def deep_isinstance(obj, cls): """replaces isinstance()""" - # return pytypes.is_of_type(obj, cls) + # compare to pytypes.is_of_type(obj, cls) return deep_issubclass(deep_type(obj), cls) @@ -173,6 +186,12 @@ def get_origin(tp): return tp if result is None else result +if sys.version_info[:2] >= (3, 7): # reuse upstream documentation if possible + get_args = functools.wraps(typing_extensions.get_args)(get_args) + get_origin = functools.wraps(typing_extensions.get_origin)(get_origin) + + +@functools.wraps(typing.get_type_hints) def get_type_hints(obj, globalns=None, localns=None, **kwargs): return typing.get_type_hints(obj, globalns=globalns, localns=localns, **kwargs) @@ -276,7 +295,21 @@ def __getitem__(cls, key): class Variadic(metaclass=_DeepVariadicSignatureType): """ - A typing-compatible drop-in replacement for multipledispatch.variadic.Variadic. + A typing-compatible drop-in replacement for :class:`multipledispatch.variadic.Variadic`. """ pass + + +__all__ = [ + "GenericTypeMeta", + "Variadic", + "deep_isinstance", + "deep_issubclass", + "deep_type", + "get_args", + "get_origin", + "get_type_hints", + "register_subclasscheck", + "typing_wrap", +] From d45984b29e0a14aa8d8af5b25547f739a1194d10 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 15:47:20 -0500 Subject: [PATCH 60/66] add docstring for deep_type --- funsor/typing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/funsor/typing.py b/funsor/typing.py index d8c743bfb..984c33966 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -18,7 +18,14 @@ @functools.singledispatch def deep_type(obj): """ - An enhanced version of :func:`type`. + An enhanced version of :func:`type` that reconstructs structured ``typing`` types + for a limited set of immutable data structures, notably ``tuple`` and ``frozenset``. 
+ Mostly intended for internal use in Funsor interpretation pattern-matching. + + Example:: + + assert deep_type((1, ("a",))) is typing.Tuple[int, typing.Tuple[str]] + assert deep_type(frozenset(["a"])) is typing.FrozenSet[str] """ # compare to pytypes.deep_type(obj) return type(obj) From 23793e38099d1e2793428f24def902d747dbaf33 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 19 Feb 2021 16:10:25 -0500 Subject: [PATCH 61/66] add docstrings for deep_issubclass and deep_isisntance --- funsor/typing.py | 50 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/funsor/typing.py b/funsor/typing.py index 984c33966..50e6ffac1 100644 --- a/funsor/typing.py +++ b/funsor/typing.py @@ -136,7 +136,29 @@ def _subclasscheck_tuple(cls, subcls): @functools.lru_cache(maxsize=None) def deep_issubclass(subcls, cls): - """replaces issubclass()""" + """ + Enhanced version of :func:`issubclass` that can handle structured types, + including Funsor terms, :class:`typing.Tuple`s, and :class:`typing.FrozenSet`s. + + Does not support :class:`typing.TypeVar`s, arbitrary :class:`typing.Generic`s, + forward references, or mutable collection types like :class:`typing.List`. + Will attempt to fall back to :func:`issubclass` when it encounters a type in + ``subcls`` or ``cls`` that it does not understand. + + Usage:: + + class A: pass + class B(A): pass + + assert deep_issubclass(typing.Tuple[int, B], typing.Tuple[int, A]) + assert not deep_issubclass(typing.Tuple[int, A], typing.Tuple[int, B]) + + assert deep_issubclass(typing.Tuple[A, A], typing.Tuple[A, ...]) + assert not deep_issubclass(typing.Tuple[B], typing.Tuple[A, ...]) + + :param subcls: A class that may be a subclass of ``cls``. + :param cls: A class that may be a parent class of ``subcls``. 
+ """ # compare to pytypes.is_subtype(subcls, cls) # handle unpacking @@ -161,9 +183,31 @@ def deep_issubclass(subcls, cls): def deep_isinstance(obj, cls): - """replaces isinstance()""" + """ + Enhanced version of :func:`isinstance` that can handle basic structured ``typing`` types, + including Funsor terms and other :class:`GenericTypeMeta` instances, + :class:`typing.Union`s, :class:`typing.Tuple`s, and :class:`typing.FrozenSet`s. + + Does not support :class:`typing.TypeVar`s, arbitrary :class:`typing.Generic`s, + forward references, or mutable generic collection types like :class:`typing.List`. + Will attempt to fall back to :func:`isinstance` when it encounters + an unsupported type in ``obj`` or ``cls``. + + Usage:: + + x = (1, ("a", "b")) + assert deep_isinstance(x, typing.Tuple[int, tuple]) + assert deep_isinstance(x, typing.Tuple[typing.Any, typing.Tuple[str, ...]]) + + :param obj: An object that may be an instance of ``cls``. + :param cls: A class that may be a parent class of ``obj``. + """ + # compare to pytypes.is_of_type(obj, cls) - return deep_issubclass(deep_type(obj), cls) + try: + return deep_issubclass(deep_type(obj), cls) + except TypeError: + return isinstance(obj, cls) def _type_to_typing(tp): From b59dc9b69a161d41fca103be407c6bff1255af93 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 00:11:09 -0500 Subject: [PATCH 62/66] default name in partialdispatcher --- funsor/registry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/registry.py b/funsor/registry.py index f741c38b4..389c5dadd 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -13,9 +13,9 @@ class PartialDispatcher(Dispatcher): Wrapper to avoid appearance in stack traces. 
""" - def __init__(self, name, default=None): + def __init__(self, default=None): self.default = default if default is None else PartialDefault(default) - super().__init__(name) + super().__init__("PartialDispatcher") if default is not None: self.add(([object],), self.default) @@ -79,7 +79,7 @@ class KeyedRegistry(object): def __init__(self, default=None): # TODO make registry a WeakKeyDictionary self.default = default if default is None else PartialDefault(default) - self.registry = defaultdict(lambda: PartialDispatcher("f", default=default)) + self.registry = defaultdict(lambda: PartialDispatcher(default=default)) def register(self, key, *types): register = self.registry[get_origin(key)].register From f011f7b6e6c72c8e7238ceb95dc2f9db902ce334 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 00:14:55 -0500 Subject: [PATCH 63/66] add typing to docs --- docs/source/index.rst | 1 + docs/source/typing.rst | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 docs/source/typing.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 16b7f6252..5739fe22c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,7 @@ Funsor is a tensor-like library for functions and distributions affine factory testing + typing .. toctree:: :glob: diff --git a/docs/source/typing.rst b/docs/source/typing.rst new file mode 100644 index 000000000..77ba6fabe --- /dev/null +++ b/docs/source/typing.rst @@ -0,0 +1,8 @@ +Typing Utiltites +---------------------- +.. 
automodule:: funsor.typing
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+

From 1fb58fffd38a0ccf27ac2de3220d8a963af34f13 Mon Sep 17 00:00:00 2001
From: Eli
Date: Sun, 21 Feb 2021 00:17:03 -0500
Subject: [PATCH 64/66] docs reqs

---
 docs/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 9e039be4a..4c67cdac7 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
 multipledispatch
 numpy>=1.7
 opt_einsum>=2.3.2
-unification
 pytest>=4.1
 makefun
+typing_extensions

From 4bfd09699916b65a6f021383fbd794571718eaa3 Mon Sep 17 00:00:00 2001
From: Eli
Date: Sun, 21 Feb 2021 00:39:43 -0500
Subject: [PATCH 65/66] fix typing docs

---
 funsor/typing.py | 33 +++++++++++++++++----------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/funsor/typing.py b/funsor/typing.py
index 50e6ffac1..38e9adf2c 100644
--- a/funsor/typing.py
+++ b/funsor/typing.py
@@ -18,7 +18,7 @@
 @functools.singledispatch
 def deep_type(obj):
     """
-    An enhanced version of :func:`type` that reconstructs structured ``typing`` types
+    An enhanced version of :func:`type` that reconstructs structured :mod:`typing` types
     for a limited set of immutable data structures, notably ``tuple`` and ``frozenset``.
 
     Mostly intended for internal use in Funsor interpretation pattern-matching.
@@ -59,7 +59,7 @@
     Decorator for registering a custom ``__subclasscheck__`` method for ``cls``
     which is only ever invoked in :func:`deep_issubclass`.
 
-    This is primarily intended for working with the `typing` library at runtime.
+    This is primarily intended for working with the :mod:`typing` library at runtime.
     Prefer overriding ``__subclasscheck__`` in the usual way with a metaclass
     where possible. 
""" @@ -78,14 +78,14 @@ def _subclasscheck_any(cls, subcls): @register_subclasscheck(typing.Union) def _subclasscheck_union(cls, subcls): - """A basic ``__subclasscheck__`` method for :class:`typing.Union`.""" + """A basic ``__subclasscheck__`` method for :class:`~typing.Union`.""" return any(deep_issubclass(subcls, arg) for arg in get_args(cls)) @register_subclasscheck(frozenset) @register_subclasscheck(typing.FrozenSet) def _subclasscheck_frozenset(cls, subcls): - """A basic ``__subclasscheck__`` method for :class:`typing.FrozenSet`.""" + """A basic ``__subclasscheck__`` method for :class:`~typing.FrozenSet`.""" if not issubclass(get_origin(subcls), frozenset): return False @@ -106,7 +106,7 @@ def _subclasscheck_frozenset(cls, subcls): @register_subclasscheck(tuple) @register_subclasscheck(typing.Tuple) def _subclasscheck_tuple(cls, subcls): - """A basic ``__subclasscheck__`` method for :class:`typing.Tuple`.""" + """A basic ``__subclasscheck__`` method for :class:`~typing.Tuple`.""" if not issubclass(get_origin(subcls), get_origin(cls)): return False @@ -138,10 +138,11 @@ def _subclasscheck_tuple(cls, subcls): def deep_issubclass(subcls, cls): """ Enhanced version of :func:`issubclass` that can handle structured types, - including Funsor terms, :class:`typing.Tuple`s, and :class:`typing.FrozenSet`s. + including Funsor terms, :class:`~typing.Tuple`, and :class:`~typing.FrozenSet`. - Does not support :class:`typing.TypeVar`s, arbitrary :class:`typing.Generic`s, - forward references, or mutable collection types like :class:`typing.List`. + Does not support more advanced :mod:`typing` features such as + :class:`~typing.TypeVar`, arbitrary :class:`~typing.Generic` subtypes, + forward references, or mutable collection types like :class:`~typing.List`. Will attempt to fall back to :func:`issubclass` when it encounters a type in ``subcls`` or ``cls`` that it does not understand. 
@@ -184,12 +185,12 @@ class B(A): pass def deep_isinstance(obj, cls): """ - Enhanced version of :func:`isinstance` that can handle basic structured ``typing`` types, - including Funsor terms and other :class:`GenericTypeMeta` instances, - :class:`typing.Union`s, :class:`typing.Tuple`s, and :class:`typing.FrozenSet`s. + Enhanced version of :func:`isinstance` that can handle basic structured :mod:`typing` types, + including Funsor terms and other :class:`~funsor.typing.GenericTypeMeta` instances, + :class:`~typing.Union`, :class:`~typing.Tuple`, and :class:`~typing.FrozenSet`. - Does not support :class:`typing.TypeVar`s, arbitrary :class:`typing.Generic`s, - forward references, or mutable generic collection types like :class:`typing.List`. + Does not support :class:`~typing.TypeVar`, arbitrary :class:`~typing.Generic`, + forward references, or mutable generic collection types like :class:`~typing.List`. Will attempt to fall back to :func:`isinstance` when it encounters an unsupported type in ``obj`` or ``cls``. @@ -254,7 +255,7 @@ def get_type_hints(obj, globalns=None, localns=None, **kwargs): class GenericTypeMeta(type): """ - Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. + Metaclass to support subtyping with parameters for pattern matching, e.g. ``Number[int, int]``. """ def __init__(cls, name, bases, dct): @@ -331,7 +332,7 @@ def __subclasscheck__(cls, subcls): class typing_wrap(metaclass=_RuntimeSubclassCheckMeta): """ - Utility callable for overriding the runtime behavior of `typing` objects. + Utility callable for overriding the runtime behavior of :mod:`typing` objects. """ pass @@ -346,7 +347,7 @@ def __getitem__(cls, key): class Variadic(metaclass=_DeepVariadicSignatureType): """ - A typing-compatible drop-in replacement for :class:`multipledispatch.variadic.Variadic`. + A typing-compatible drop-in replacement for :class:`~multipledispatch.variadic.Variadic`. 
""" pass From c0f187d0379a777caf58c33f5ec6d5422fe049e0 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 00:51:07 -0500 Subject: [PATCH 66/66] remove name argument from partialdispatcher --- test/test_typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_typing.py b/test/test_typing.py index 90e916820..4a2393503 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -289,7 +289,7 @@ def f(*args): def test_dispatch_typing(): - f = PartialDispatcher("f", default=lambda *args: 1) + f = PartialDispatcher(lambda *args: 1) @f.register() def f2(a: int, b: int) -> int: @@ -321,7 +321,7 @@ def f5(a: Tuple[int, float], b: Tuple[int, float]) -> int: def test_variadic_dispatch_typing(): - f = PartialDispatcher("f", default=lambda *args: 1) + f = PartialDispatcher(lambda *args: 1) @f.register() def _(a: int, b: int) -> int: