From 89e7abb86933e1d0b83c106c2eef7b9236713030 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 13 Jan 2021 20:13:16 -0500 Subject: [PATCH 1/6] Add Finitary funsor for lazy op application --- funsor/affine.py | 10 +++---- funsor/tensor.py | 71 +++++++++++++++++++++++++++--------------------- funsor/terms.py | 15 ++++++++++ 3 files changed, 60 insertions(+), 36 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index 630502ec9..86e2fabf2 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -7,8 +7,8 @@ import opt_einsum from funsor.interpreter import gensym -from funsor.tensor import Einsum, Tensor, get_default_prototype -from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, Bint +from funsor.tensor import EinsumOp, Tensor, get_default_prototype +from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable, Bint from . import ops @@ -91,12 +91,12 @@ def _(fn): return affine_inputs(fn.arg) - fn.reduced_vars -@affine_inputs.register(Einsum) +@affine_inputs.register(Finitary[EinsumOp, tuple]) def _(fn): # This is simply a multiary version of the above Binary(ops.mul, ...) case. results = [] - for i, x in enumerate(fn.operands): - others = fn.operands[:i] + fn.operands[i+1:] + for i, x in enumerate(fn.args): + others = fn.args[:i] + fn.args[i+1:] other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset()) results.append(affine_inputs(x) - other_inputs) # This multilinear case introduces incompleteness, since some vars diff --git a/funsor/tensor.py b/funsor/tensor.py index 0dc1a9542..ca9724328 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -21,6 +21,7 @@ from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp from funsor.terms import ( Binary, + Finitary, Funsor, FunsorMeta, Lambda, @@ -694,6 +695,18 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) +# TODO handle variable length Tuple +@eager.register(Finitary, Op, typing.Tuple[Tensor]) +@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor]) +@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor, Tensor]) +@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor, Tensor, Tensor]) +@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) +def eager_finitary_generic_tensors(op, args): + inputs, raw_args = align_tensors(*args) + raw_result = op(*raw_args) + return Tensor(raw_result, inputs, args[0].dtype) + + @eager.register(Lambda, Variable, Tensor) def eager_lambda(var, expr): inputs = expr.inputs.copy() @@ -949,7 +962,29 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]: return functools.partial(_function, inputs, output) -class Einsum(Funsor): +class EinsumOp(ops.Op, metaclass=ops.CachedOpMeta): + def __init__(self, equation): + self.equation = equation + + +@find_domain.register(EinsumOp) +def _find_domain_einsum(op, *operands): + equation = op.equation + ein_inputs, ein_output = equation.split('->') + ein_inputs = ein_inputs.split(',') + size_dict = {} + for ein_input, x in zip(ein_inputs, operands): + assert x.dtype == 'real' + assert len(ein_input) == len(x.output.shape) + for name, size in zip(ein_input, x.output.shape): + other_size = size_dict.setdefault(name, size) + if other_size != size: + raise ValueError("Size mismatch at {}: {} vs {}" + .format(name, size, other_size)) + return Reals[tuple(size_dict[d] for d in ein_output)] + + +def Einsum(equation, operands): """ Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors. @@ -960,40 +995,14 @@ class Einsum(Funsor): :param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation. :param tuple operands: A tuple of input funsors. """ - def __init__(self, equation, operands): - assert isinstance(equation, str) - assert isinstance(operands, tuple) - assert all(isinstance(x, Funsor) for x in operands) - ein_inputs, ein_output = equation.split('->') - ein_inputs = ein_inputs.split(',') - size_dict = {} - inputs = OrderedDict() - assert len(ein_inputs) == len(operands) - for ein_input, x in zip(ein_inputs, operands): - assert x.dtype == 'real' - inputs.update(x.inputs) - assert len(ein_input) == len(x.output.shape) - for name, size in zip(ein_input, x.output.shape): - other_size = size_dict.setdefault(name, size) - if other_size != size: - raise ValueError("Size mismatch at {}: {} vs {}" - .format(name, size, other_size)) - output = Reals[tuple(size_dict[d] for d in ein_output)] - super(Einsum, self).__init__(inputs, output) - self.equation = equation - self.operands = operands - - def __repr__(self): - return 'Einsum({}, {})'.format(repr(self.equation), repr(self.operands)) - - def __str__(self): - return 'Einsum({}, {})'.format(repr(self.equation), str(self.operands)) + return Finitary(EinsumOp(equation), tuple(operands)) -@eager.register(Einsum, str, tuple) -def eager_einsum(equation, operands): +@eager.register(Finitary, EinsumOp, tuple) +def eager_einsum(op, operands): if all(isinstance(x, Tensor) for x in operands): # Make new symbols for inputs of operands. + equation = op.equation inputs = OrderedDict() for x in operands: inputs.update(x.inputs) diff --git a/funsor/terms.py b/funsor/terms.py index 97b8cb740..634491480 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1309,6 +1309,21 @@ def eager_binary_align_align(op, lhs, rhs): return Binary(op, lhs.arg, rhs.arg) +class Finitary(Funsor): + + def __init__(self, op, args): + assert isinstance(op, ops.Op) + assert isinstance(args, tuple) + assert all(isinstance(v, Funsor) for v in args) + inputs = OrderedDict() + for arg in args: + inputs.update(arg.inputs) + output = find_domain(op, *(arg.output for arg in args)) + super().__init__(inputs, output) + self.op = op + self.args = args + + class Stack(Funsor): """ Stack of funsors along a new input dimension. From 66587a9ffcb161a7d2be363b6ea73d71e92059b8 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 01:21:22 -0500 Subject: [PATCH 2/6] format --- funsor/affine.py | 2 +- funsor/tensor.py | 11 ++++++----- funsor/terms.py | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index d3c0ab31e..2cdb8578d 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -97,7 +97,7 @@ def _(fn): # This is simply a multiary version of the above Binary(ops.mul, ...) case. results = [] for i, x in enumerate(fn.args): - others = fn.args[:i] + fn.args[i + 1:] + others = fn.args[:i] + fn.args[i + 1 :] other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset()) results.append(affine_inputs(x) - other_inputs) # This multilinear case introduces incompleteness, since some vars diff --git a/funsor/tensor.py b/funsor/tensor.py index fc598a27e..1ef8bb3bf 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1040,17 +1040,18 @@ def __init__(self, equation): @find_domain.register(EinsumOp) def _find_domain_einsum(op, *operands): equation = op.equation - ein_inputs, ein_output = equation.split('->') - ein_inputs = ein_inputs.split(',') + ein_inputs, ein_output = equation.split("->") + ein_inputs = ein_inputs.split(",") size_dict = {} for ein_input, x in zip(ein_inputs, operands): - assert x.dtype == 'real' + assert x.dtype == "real" assert len(ein_input) == len(x.output.shape) for name, size in zip(ein_input, x.output.shape): other_size = size_dict.setdefault(name, size) if other_size != size: - raise ValueError("Size mismatch at {}: {} vs {}" - .format(name, size, other_size)) + raise ValueError( + "Size mismatch at {}: {} vs {}".format(name, size, other_size) + ) return Reals[tuple(size_dict[d] for d in ein_output)] diff --git a/funsor/terms.py b/funsor/terms.py index c28981300..b8d91818e 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1289,7 +1289,6 @@ def eager_binary_align_align(op, lhs, rhs): class Finitary(Funsor): - def __init__(self, op, args): assert isinstance(op, ops.Op) assert isinstance(args, tuple) From d19244565956cd87bbeb7995c3210c96a15ea8ca Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 01:22:16 -0500 Subject: [PATCH 3/6] variadic tuple --- funsor/tensor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index 1ef8bb3bf..aee6b85a0 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -771,12 +771,7 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) -# TODO handle variable length Tuple -@eager.register(Finitary, Op, typing.Tuple[Tensor]) -@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor]) -@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor, Tensor]) -@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor, Tensor, Tensor]) -@eager.register(Finitary, Op, typing.Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) +@eager.register(Finitary, Op, typing.Tuple[Tensor, ...]) def eager_finitary_generic_tensors(op, args): inputs, raw_args = align_tensors(*args) raw_result = op(*raw_args) From bb01e5f333d68928f09533bd181086361b862304 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 01:34:03 -0500 Subject: [PATCH 4/6] fix tests --- funsor/tensor.py | 63 +++++++++++++++++++++------------------------ test/test_affine.py | 4 +-- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index aee6b85a0..3d0754ee4 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1040,8 +1040,8 @@ def _find_domain_einsum(op, *operands): size_dict = {} for ein_input, x in zip(ein_inputs, operands): assert x.dtype == "real" - assert len(ein_input) == len(x.output.shape) - for name, size in zip(ein_input, x.output.shape): + assert len(ein_input) == len(x.shape) + for name, size in zip(ein_input, x.shape): other_size = size_dict.setdefault(name, size) if other_size != size: raise ValueError( @@ -1064,39 +1064,36 @@ def Einsum(equation, operands): return Finitary(EinsumOp(equation), tuple(operands)) -@eager.register(Finitary, EinsumOp, tuple) +@eager.register(Finitary, EinsumOp, typing.Tuple[Tensor, ...]) def eager_einsum(op, operands): - if all(isinstance(x, Tensor) for x in operands): - # Make new symbols for inputs of operands. - equation = op.equation - inputs = OrderedDict() - for x in operands: - inputs.update(x.inputs) - symbols = set(equation) - get_symbol = iter(map(opt_einsum.get_symbol, itertools.count())) - new_symbols = {} - for k in inputs: + # Make new symbols for inputs of operands. + equation = op.equation + inputs = OrderedDict() + for x in operands: + inputs.update(x.inputs) + symbols = set(equation) + get_symbol = iter(map(opt_einsum.get_symbol, itertools.count())) + new_symbols = {} + for k in inputs: + symbol = next(get_symbol) + while symbol in symbols: symbol = next(get_symbol) - while symbol in symbols: - symbol = next(get_symbol) - symbols.add(symbol) - new_symbols[k] = symbol - - # Manually broadcast using einsum symbols. - assert "." not in equation - ins, out = equation.split("->") - ins = ins.split(",") - ins = [ - "".join(new_symbols[k] for k in x.inputs) + x_out - for x, x_out in zip(operands, ins) - ] - out = "".join(new_symbols[k] for k in inputs) + out - equation = ",".join(ins) + "->" + out - - data = ops.einsum(equation, *[x.data for x in operands]) - return Tensor(data, inputs) - - return None # defer to default implementation + symbols.add(symbol) + new_symbols[k] = symbol + + # Manually broadcast using einsum symbols. + assert "." not in equation + ins, out = equation.split("->") + ins = ins.split(",") + ins = [ + "".join(new_symbols[k] for k in x.inputs) + x_out + for x, x_out in zip(operands, ins) + ] + out = "".join(new_symbols[k] for k in inputs) + out + equation = ",".join(ins) + "->" + out + + data = ops.einsum(equation, *[x.data for x in operands]) + return Tensor(data, inputs) def tensordot(x, y, dims): diff --git a/test/test_affine.py b/test/test_affine.py index a991ea3c8..4cc46a254 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -9,7 +9,7 @@ from funsor.cnf import Contraction from funsor.domains import Bint, Real, Reals # noqa: F401 from funsor.tensor import Einsum, Tensor -from funsor.terms import Number, Unary, Variable +from funsor.terms import Finitary, Number, Unary, Variable from funsor.testing import ( assert_close, check_funsor, @@ -117,7 +117,7 @@ def test_affine_subs(expr, expected_type, expected_inputs): def test_extract_affine(expr): x = eval(expr) assert is_affine(x) - assert isinstance(x, (Unary, Contraction, Einsum)) + assert isinstance(x, (Unary, Contraction, Finitary)) real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() if d.dtype == "real") const, coeffs = extract_affine(x) From 14659f0efb874c5291798d5e6d9ff72083cce756 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 02:00:06 -0500 Subject: [PATCH 5/6] fix mvn_affine test --- test/test_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distribution.py b/test/test_distribution.py index d0794a1eb..b01a5da6a 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -639,7 +639,7 @@ def test_mvn_affine_einsum(): data = dict(x=Tensor(randn(2, 2)), y=Tensor(randn(()))) with lazy: d = to_funsor(random_mvn((), 3), Real) - d = d(value=Einsum("abc,bc->a", c, x) + y) + d = d(value=Einsum("abc,bc->a", (c, x)) + y) _check_mvn_affine(d, data) From 40810e52ee4eca4f57d4e89289177177e0507fff Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 22:40:36 -0500 Subject: [PATCH 6/6] update Einsum interface --- funsor/affine.py | 2 +- funsor/tensor.py | 4 ++-- test/test_affine.py | 8 ++++---- test/test_distribution.py | 2 +- test/test_gaussian.py | 2 +- test/test_tensor.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index bfa1a9229..cf8f695ab 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -115,7 +115,7 @@ def extract_affine(fn): x = ... const, coeffs = extract_affine(x) - y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) + y = sum(Einsum(eqn, coeff, Variable(var, coeff.output)) for var, (coeff, eqn) in coeffs.items()) assert_close(y, x) assert frozenset(coeffs) == affine_inputs(x) diff --git a/funsor/tensor.py b/funsor/tensor.py index 3d0754ee4..7db22769b 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1050,7 +1050,7 @@ def _find_domain_einsum(op, *operands): return Reals[tuple(size_dict[d] for d in ein_output)] -def Einsum(equation, operands): +def Einsum(equation, *operands): """ Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors. @@ -1129,7 +1129,7 @@ def tensordot(x, y, dims): symbols[y_start:y_end], symbols[x_start:y_start] + symbols[x_end:y_end], ) - return Einsum(equation, (x, y)) + return Einsum(equation, x, y) def stack(parts, dim=0): diff --git a/test/test_affine.py b/test/test_affine.py index 4cc46a254..70e811099 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -107,9 +107,9 @@ def test_affine_subs(expr, expected_type, expected_inputs): "Variable('x', Reals[2]) * randn(2) + ones(2)", "Variable('x', Reals[2]) + Tensor(randn(3, 2), OrderedDict(i=Bint[3]))", "Einsum('abcd,ac->bd'," - " (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))", + " Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))", "Tensor(randn(3, 5)) + Einsum('abcd,ac->bd'," - " (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))", + " Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))", "Variable('x', Reals[2, 8])[0] + randn(8)", "Variable('x', Reals[2, 8])[Variable('i', Bint[2])] / 4 - 3.5", ], @@ -134,7 +134,7 @@ def test_extract_affine(expr): assert isinstance(expected, Tensor) actual = const + sum( - Einsum(eqn, (coeff, subs[k])) for k, (coeff, eqn) in coeffs.items() + Einsum(eqn, coeff, subs[k]) for k, (coeff, eqn) in coeffs.items() ) assert isinstance(actual, Tensor) assert_close(actual, expected) @@ -157,7 +157,7 @@ def test_extract_affine(expr): "Variable('x', Reals[2,3]) @ Variable('y', Reals[3,4])", "random_gaussian(OrderedDict(x=Real))", "Einsum('abcd,ac->bd'," - " (Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4])))", + " Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4]))", ], ) def test_not_is_affine(expr): diff --git a/test/test_distribution.py b/test/test_distribution.py index 4e25374eb..58a402e35 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -639,7 +639,7 @@ def test_mvn_affine_einsum(): data = dict(x=Tensor(randn(2, 2)), y=Tensor(randn(()))) with lazy: d = to_funsor(random_mvn((), 3), Real) - d = d(value=Einsum("abc,bc->a", (c, x)) + y) + d = d(value=Einsum("abc,bc->a", c, x) + y) _check_mvn_affine(d, data) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index d9e66af02..264953907 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -368,7 +368,7 @@ def test_eager_subs_variable(): ( ( "y", - 'Einsum("abc,bc->a", (Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5])))', + 'Einsum("abc,bc->a", Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5]))', ), ), ], diff --git a/test/test_tensor.py b/test/test_tensor.py index 2050fb105..5a6e62f1e 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -950,7 +950,7 @@ def test_einsum(equation): tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs] funsors = [Tensor(x) for x in tensors] expected = Tensor(ops.einsum(equation, *tensors)) - actual = Einsum(equation, tuple(funsors)) + actual = Einsum(equation, *funsors) assert_close(actual, expected, atol=1e-5, rtol=None) @@ -968,7 +968,7 @@ def test_batched_einsum(equation, batch1, batch2): random_tensor(batch, Reals[tuple(sizes[d] for d in dims)]) for batch, dims in zip([batch1, batch2], inputs) ] - actual = Einsum(equation, tuple(funsors)) + actual = Einsum(equation, *funsors) _equation = ",".join("..." + i for i in inputs) + "->..." + output inputs, tensors = align_tensors(*funsors)