Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
9d653cb
Extract parametric type metaclass from FunsorMeta and domains
eb8680 Jan 19, 2021
0177b07
restore assertion in FunsorMeta.__getitem__
eb8680 Jan 19, 2021
1462277
switch to pytypes multipledispatch backend and use pytypes in generic…
eb8680 Jan 19, 2021
2d6ee23
pin dependencies using new pip syntax
eb8680 Jan 20, 2021
5b98eba
fix dependencies
eb8680 Jan 20, 2021
664f3ed
attempt to work around need to change multipledispatch and revert
eb8680 Jan 21, 2021
d850d07
refactor Contraction to non-variadic patterns and remove all uses of …
eb8680 Jan 21, 2021
be0a4cb
add funsor.typing module
eb8680 Jan 21, 2021
b2931f4
use funsor.typing
eb8680 Jan 21, 2021
24985db
uncomment pattern
eb8680 Jan 21, 2021
6538ba8
move more things to funsor.typing
eb8680 Jan 21, 2021
07fc6b7
organize code in funsor.typing
eb8680 Jan 21, 2021
b598204
nit
eb8680 Jan 21, 2021
7630c1d
change syntax of typing_wrap to callable
eb8680 Jan 22, 2021
09c2c8b
remove memoize
eb8680 Jan 22, 2021
4c8eb9e
attempt to use custom issubclass
eb8680 Jan 27, 2021
0e842c3
Add FrozenSet to deep_issubclass and start deep_type
eb8680 Jan 30, 2021
525de14
use deep_type in reflect
eb8680 Jan 30, 2021
c4e19ec
fix subclass tests
eb8680 Jan 30, 2021
2cf8aa9
remove pytypes dependency
eb8680 Jan 30, 2021
00fee93
revert nits
eb8680 Jan 30, 2021
c55eb8e
Revert Variadic removal
eb8680 Jan 30, 2021
1c950bf
lint
eb8680 Jan 30, 2021
1c2a2a8
Merge branch 'master' into typing-module
eb8680 Jan 30, 2021
0430944
add typing_extensions dependency
eb8680 Jan 30, 2021
22b180b
implement add instead of register in typingdispatcher and add Variadic
eb8680 Jan 31, 2021
6509089
use new variadic throughout
eb8680 Jan 31, 2021
dec06a2
nits
eb8680 Jan 31, 2021
38c40df
move dispatcher to registry
eb8680 Jan 31, 2021
8b9fc57
remove aliasing and move classname to FunsorMeta
eb8680 Jan 31, 2021
d212394
fix typo
eb8680 Jan 31, 2021
c36dd48
Remove changes to domains from this branch
eb8680 Jan 31, 2021
0b71b48
keep classname lazy
eb8680 Jan 31, 2021
78bb134
run black
eb8680 Feb 2, 2021
418963b
reformat with black and merge master
eb8680 Feb 2, 2021
bac807c
fix import
eb8680 Feb 2, 2021
e2a2428
Merge branch 'master' into typing-module-funsormeta
eb8680 Feb 3, 2021
f007342
lint
eb8680 Feb 3, 2021
1310bb5
try to fix python 3.6
eb8680 Feb 3, 2021
646d8e5
Add test stages with other python versions to travis
eb8680 Feb 3, 2021
fd96c6b
Merge branch 'python-versions-travis' into typing-module-funsormeta
eb8680 Feb 3, 2021
2d79c0a
attempt to fix
eb8680 Feb 3, 2021
49627c5
Merge branch 'master' into typing-module-funsormeta
eb8680 Feb 17, 2021
01016b9
remove parametric subclass tests from test_terms
eb8680 Feb 17, 2021
ee1a611
add new test_typing.py
eb8680 Feb 17, 2021
cb5eea5
get_origin
eb8680 Feb 17, 2021
107aebe
more get_origin
eb8680 Feb 17, 2021
b6d2d46
add more tests
eb8680 Feb 17, 2021
b773aa4
fixes for variadic dispatch
eb8680 Feb 17, 2021
37d7022
add another dispatch test with no variadic patterns
eb8680 Feb 17, 2021
b52330f
Merge branch 'master' into typing-module-funsormeta
eb8680 Feb 19, 2021
9a73f56
add python 3.7,8,9 stages to github actions
eb8680 Feb 19, 2021
d8e0487
split up deep_issubclass
eb8680 Feb 19, 2021
90566d1
dont use defaultdict
eb8680 Feb 19, 2021
277d34c
typo
eb8680 Feb 19, 2021
d54703c
register_subclasscheck
eb8680 Feb 19, 2021
a5d64d4
fix tests
eb8680 Feb 19, 2021
c62e991
add test for get_type_hints
eb8680 Feb 19, 2021
1b9d7ff
handle no return hint
eb8680 Feb 19, 2021
30be93a
rename test
eb8680 Feb 19, 2021
e4c6a1f
attempt to use matrix syntax in github actions
eb8680 Feb 19, 2021
adc1ac0
fix union
eb8680 Feb 19, 2021
b1e8125
typing.get_type_hints
eb8680 Feb 19, 2021
e6710e1
small optimizations
eb8680 Feb 19, 2021
63d2c72
add some documentation
eb8680 Feb 19, 2021
d45984b
add docstring for deep_type
eb8680 Feb 19, 2021
23793e3
add docstrings for deep_issubclass and deep_isisntance
eb8680 Feb 19, 2021
b59dc9b
default name in partialdispatcher
eb8680 Feb 21, 2021
f011f7b
add typing to docs
eb8680 Feb 21, 2021
1fb58ff
docs reqs
eb8680 Feb 21, 2021
4bfd096
fix typing docs
eb8680 Feb 21, 2021
115d188
Merge branch 'master' into typing-module-funsormeta
eb8680 Feb 21, 2021
c0f187d
remove name argument from partialdispatcher
eb8680 Feb 21, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6]
python-version: [3.6, 3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
multipledispatch
numpy>=1.7
opt_einsum>=2.3.2
unification
pytest>=4.1
makefun
typing_extensions
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Funsor is a tensor-like library for functions and distributions
affine
factory
testing
typing

.. toctree::
:glob:
Expand Down
8 changes: 8 additions & 0 deletions docs/source/typing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Typing Utiltites
----------------------
.. automodule:: funsor.typing
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

2 changes: 1 addition & 1 deletion funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Tuple, Union

import opt_einsum
from multipledispatch.variadic import Variadic

import funsor
import funsor.ops as ops
Expand All @@ -32,6 +31,7 @@
Variable,
to_funsor,
)
from funsor.typing import Variadic
from funsor.util import broadcast_shape, get_backend, quote


Expand Down
2 changes: 1 addition & 1 deletion funsor/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Tuple, Union

from multipledispatch import dispatch
from multipledispatch.variadic import Variadic

import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
Expand All @@ -18,6 +17,7 @@
from funsor.ops import AssociativeOp
from funsor.tensor import Tensor, align_tensor
from funsor.terms import Funsor, Independent, Number, Reduce, Unary
from funsor.typing import Variadic


@dispatch(str, str, Variadic[(Gaussian, GaussianMixture)])
Expand Down
2 changes: 1 addition & 1 deletion funsor/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import collections

from multipledispatch.variadic import Variadic
from opt_einsum.paths import greedy

import funsor.interpreter as interpreter
Expand All @@ -18,6 +17,7 @@
from funsor.interpreter import get_interpretation
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp
from funsor.terms import Funsor
from funsor.typing import Variadic

unfold_base = DispatchedInterpretation()
unfold = PrioritizedInterpretation(unfold_base, normalize_base, lazy)
Expand Down
57 changes: 41 additions & 16 deletions funsor/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,50 @@

from collections import defaultdict

from multipledispatch import Dispatcher
from multipledispatch.conflict import supercedes
from multipledispatch.dispatcher import Dispatcher, expand_tuples

from funsor.typing import Variadic, deep_type, get_origin, get_type_hints, typing_wrap


class PartialDispatcher(Dispatcher):
"""
Wrapper to avoid appearance in stack traces.
"""

def __init__(self, default=None):
self.default = default if default is None else PartialDefault(default)
super().__init__("PartialDispatcher")
if default is not None:
self.add(([object],), self.default)

def add(self, signature, func):

# Handle annotations
if not signature:
annotations = get_type_hints(func)
annotations.pop("return", None)
if annotations:
signature = tuple(annotations.values())

# Handle some union types by expanding at registration time
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return

# Handle variadic types
signature = (
Variadic[tuple(tp)] if isinstance(tp, list) else tp for tp in signature
)

signature = tuple(map(typing_wrap, signature))
super().add(signature, func)

def partial_call(self, *args):
"""
Likde :meth:`__call__` but avoids calling ``func()``.
"""
types = tuple(map(type, args))
types = tuple(map(typing_wrap, map(deep_type, args)))
try:
func = self._cache[types]
except KeyError:
Expand All @@ -29,6 +59,9 @@ def partial_call(self, *args):
self._cache[types] = func
return func

def __call__(self, *args):
return self.partial_call(*args)(*args)


class PartialDefault:
def __init__(self, default):
Expand All @@ -44,20 +77,12 @@ def partial_call(self, *args):

class KeyedRegistry(object):
def __init__(self, default=None):
self.default = default if default is None else PartialDefault(default)
# TODO make registry a WeakKeyDictionary
self.registry = defaultdict(lambda: PartialDispatcher("f"))
self.default = default if default is None else PartialDefault(default)
self.registry = defaultdict(lambda: PartialDispatcher(default=default))

def register(self, key, *types):
key = getattr(key, "__origin__", key)
register = self.registry[key].register
if self.default:
objects = (object,) * len(types)
try:
if objects != types and supercedes(types, objects):
register(*objects)(self.default)
except TypeError:
pass # mysterious source of ambiguity in Python 3.5 breaks this
register = self.registry[get_origin(key)].register

# This decorator supports stacking multiple decorators, which is not
# supported by multipledipatch (which returns a Dispatch object rather
Expand All @@ -69,10 +94,10 @@ def decorator(fn):
return decorator

def __contains__(self, key):
return key in self.registry
return get_origin(key) in self.registry

def __getitem__(self, key):
key = getattr(key, "__origin__", key)
key = get_origin(key)
if self.default is None:
return self.registry[key]
return self.registry.get(key, self.default)
Expand Down
2 changes: 1 addition & 1 deletion funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np
import opt_einsum
from multipledispatch import dispatch
from multipledispatch.variadic import Variadic

import funsor

Expand All @@ -35,6 +34,7 @@
to_data,
to_funsor,
)
from .typing import Variadic
from .util import get_backend, get_tracing_state, getargspec, is_nn_module, quote


Expand Down
130 changes: 15 additions & 115 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from weakref import WeakValueDictionary

from multipledispatch import dispatch
from multipledispatch.variadic import Variadic, isvariadic

import funsor.interpreter as interpreter
import funsor.ops as ops
Expand All @@ -29,6 +28,7 @@
from funsor.interpreter import PatternMissingError, interpret
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS
from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_origin
from funsor.util import getargspec, lazy_property, pretty, quote

from . import instrument, interpreter, ops
Expand Down Expand Up @@ -105,23 +105,16 @@ def reflect(cls, *args, **kwargs):
if cache_key in cls._cons_cache:
return cls._cons_cache[cache_key]

arg_types = tuple(
typing.Tuple[tuple(map(type, arg))]
if (type(arg) is tuple and all(isinstance(a, Funsor) for a in arg))
else typing.Tuple
if (type(arg) is tuple and not arg)
else type(arg)
for arg in args
)
cls_specific = (cls.__origin__ if cls.__args__ else cls)[arg_types]
arg_types = tuple(map(deep_type, args))
cls_specific = get_origin(cls)[arg_types]
result = super(FunsorMeta, cls_specific).__call__(*args)
result._ast_values = args

if instrument.PROFILE:
size, depth, width = _get_ast_stats(result)
instrument.COUNTERS["ast_size"][size] += 1
instrument.COUNTERS["ast_depth"][depth] += 1
classname = getattr(cls, "__origin__", cls).__name__
classname = get_origin(cls).__name__
instrument.COUNTERS["funsor"][classname] += 1
instrument.COUNTERS[classname][width] += 1

Expand All @@ -132,7 +125,7 @@ def reflect(cls, *args, **kwargs):
return result


class FunsorMeta(type):
class FunsorMeta(GenericTypeMeta):
"""
Metaclass for Funsors to perform four independent tasks:

Expand All @@ -157,15 +150,17 @@ class FunsorMeta(type):

def __init__(cls, name, bases, dct):
super(FunsorMeta, cls).__init__(name, bases, dct)
if not hasattr(cls, "__args__"):
cls.__args__ = ()
if cls.__args__:
(base,) = bases
cls.__origin__ = base
else:
if not cls.__args__:
cls._ast_fields = getargspec(cls.__init__)[0][1:]
cls._cons_cache = WeakValueDictionary()
cls._type_cache = WeakValueDictionary()

def __getitem__(cls, arg_types):
if not isinstance(arg_types, tuple):
arg_types = (arg_types,)
assert len(arg_types) == len(
cls._ast_fields
), "Must provide exactly one type per subexpression"
return super().__getitem__(arg_types)

def __call__(cls, *args, **kwargs):
if cls.__args__:
Expand All @@ -181,104 +176,9 @@ def __call__(cls, *args, **kwargs):

return interpret(cls, *args)

def __getitem__(cls, arg_types):
if not isinstance(arg_types, tuple):
arg_types = (arg_types,)
assert not any(
isvariadic(arg_type) for arg_type in arg_types
), "nested variadic types not supported"
# switch tuple to typing.Tuple
arg_types = tuple(
typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types
)
if arg_types not in cls._type_cache:
assert not cls.__args__, "cannot subscript a subscripted type {}".format(
cls
)
assert len(arg_types) == len(
cls._ast_fields
), "must provide types for all params"
new_dct = cls.__dict__.copy()
new_dct.update({"__args__": arg_types})
# type(cls) to handle FunsorMeta subclasses
cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct)
return cls._type_cache[arg_types]

def __subclasscheck__(cls, subcls): # issubclass(subcls, cls)
if cls is subcls:
return True
if not isinstance(subcls, FunsorMeta):
return super(FunsorMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(
subcls
)

cls_origin = getattr(cls, "__origin__", cls)
subcls_origin = getattr(subcls, "__origin__", subcls)
if not super(FunsorMeta, cls_origin).__subclasscheck__(subcls_origin):
return False

if cls.__args__:
if not subcls.__args__:
return False
if len(cls.__args__) != len(subcls.__args__):
return False
for subcls_param, param in zip(subcls.__args__, cls.__args__):
if not _issubclass_tuple(subcls_param, param):
return False
return True

@lazy_property
def classname(cls):
return cls.__name__ + "[{}]".format(
", ".join(
str(getattr(t, "classname", t)) # Tuple doesn't have __name__
for t in cls.__args__
)
)


def _issubclass_tuple(subcls, cls):
"""
utility for pattern matching with tuple subexpressions
"""
# so much boilerplate...
cls_is_union = (
hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union
)
if isinstance(cls, tuple) or cls_is_union:
return any(
_issubclass_tuple(subcls, option)
for option in (getattr(cls, "__args__", []) if cls_is_union else cls)
)

subcls_is_union = (
hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union
)
if isinstance(subcls, tuple) or subcls_is_union:
return any(
_issubclass_tuple(option, cls)
for option in (
getattr(subcls, "__args__", []) if subcls_is_union else subcls
)
)

subcls_is_tuple = hasattr(subcls, "__origin__") and (
subcls.__origin__ or subcls
) in (tuple, typing.Tuple)
cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (
tuple,
typing.Tuple,
)
if subcls_is_tuple != cls_is_tuple:
return False
if not cls_is_tuple:
return issubclass(subcls, cls)
if not cls.__args__:
return True
if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__):
return False

return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__))
return repr(cls)


def _convert_reduced_vars(reduced_vars, inputs):
Expand Down
Loading