diff --git a/funsor/affine.py b/funsor/affine.py index ddcefcc2b..cf8f695ab 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -6,9 +6,10 @@ import opt_einsum +from funsor.domains import Bint from funsor.interpreter import gensym -from funsor.tensor import Einsum, Tensor, get_default_prototype -from funsor.terms import Binary, Bint, Funsor, Lambda, Reduce, Unary, Variable +from funsor.tensor import EinsumOp, Tensor, get_default_prototype +from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable from . import ops @@ -91,12 +92,12 @@ def _(fn): return affine_inputs(fn.arg) - fn.reduced_vars -@affine_inputs.register(Einsum) +@affine_inputs.register(Finitary[EinsumOp, tuple]) def _(fn): # This is simply a multiary version of the above Binary(ops.mul, ...) case. results = [] - for i, x in enumerate(fn.operands): - others = fn.operands[:i] + fn.operands[i + 1 :] + for i, x in enumerate(fn.args): + others = fn.args[:i] + fn.args[i + 1 :] other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset()) results.append(affine_inputs(x) - other_inputs) # This multilinear case introduces incompleteness, since some vars @@ -114,7 +115,7 @@ def extract_affine(fn): x = ... const, coeffs = extract_affine(x) - y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) + y = sum(Einsum(eqn, coeff, Variable(var, coeff.output)) for var, (coeff, eqn) in coeffs.items()) assert_close(y, x) assert frozenset(coeffs) == affine_inputs(x) diff --git a/funsor/tensor.py b/funsor/tensor.py index 9ac05ac4d..7db22769b 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -21,6 +21,7 @@ from .ops import GetitemOp, MatmulOp, Op, ReshapeOp from .terms import ( Binary, + Finitary, Funsor, FunsorMeta, Lambda, @@ -770,6 +771,13 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) +@eager.register(Finitary, Op, typing.Tuple[Tensor, ...]) +def eager_finitary_generic_tensors(op, args): + inputs, raw_args = align_tensors(*args) + raw_result = op(*raw_args) + return Tensor(raw_result, inputs, args[0].dtype) + + @eager.register(Lambda, Variable, Tensor) def eager_lambda(var, expr): inputs = expr.inputs.copy() @@ -1019,7 +1027,30 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]: return functools.partial(_function, inputs, output) -class Einsum(Funsor): +class EinsumOp(ops.Op, metaclass=ops.CachedOpMeta): + def __init__(self, equation): + self.equation = equation + + +@find_domain.register(EinsumOp) +def _find_domain_einsum(op, *operands): + equation = op.equation + ein_inputs, ein_output = equation.split("->") + ein_inputs = ein_inputs.split(",") + size_dict = {} + for ein_input, x in zip(ein_inputs, operands): + assert x.dtype == "real" + assert len(ein_input) == len(x.shape) + for name, size in zip(ein_input, x.shape): + other_size = size_dict.setdefault(name, size) + if other_size != size: + raise ValueError( + "Size mismatch at {}: {} vs {}".format(name, size, other_size) + ) + return Reals[tuple(size_dict[d] for d in ein_output)] + + +def Einsum(equation, *operands): """ Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors. @@ -1030,70 +1061,39 @@ class Einsum(Funsor): :param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation. :param tuple operands: A tuple of input funsors. """ - - def __init__(self, equation, operands): - assert isinstance(equation, str) - assert isinstance(operands, tuple) - assert all(isinstance(x, Funsor) for x in operands) - ein_inputs, ein_output = equation.split("->") - ein_inputs = ein_inputs.split(",") - size_dict = {} - inputs = OrderedDict() - assert len(ein_inputs) == len(operands) - for ein_input, x in zip(ein_inputs, operands): - assert x.dtype == "real" - inputs.update(x.inputs) - assert len(ein_input) == len(x.output.shape) - for name, size in zip(ein_input, x.output.shape): - other_size = size_dict.setdefault(name, size) - if other_size != size: - raise ValueError( - "Size mismatch at {}: {} vs {}".format(name, size, other_size) - ) - output = Reals[tuple(size_dict[d] for d in ein_output)] - super(Einsum, self).__init__(inputs, output) - self.equation = equation - self.operands = operands - - def __repr__(self): - return "Einsum({}, {})".format(repr(self.equation), repr(self.operands)) - - def __str__(self): - return "Einsum({}, {})".format(repr(self.equation), str(self.operands)) + return Finitary(EinsumOp(equation), tuple(operands)) -@eager.register(Einsum, str, tuple) -def eager_einsum(equation, operands): - if all(isinstance(x, Tensor) for x in operands): - # Make new symbols for inputs of operands. - inputs = OrderedDict() - for x in operands: - inputs.update(x.inputs) - symbols = set(equation) - get_symbol = iter(map(opt_einsum.get_symbol, itertools.count())) - new_symbols = {} - for k in inputs: +@eager.register(Finitary, EinsumOp, typing.Tuple[Tensor, ...]) +def eager_einsum(op, operands): + # Make new symbols for inputs of operands. + equation = op.equation + inputs = OrderedDict() + for x in operands: + inputs.update(x.inputs) + symbols = set(equation) + get_symbol = iter(map(opt_einsum.get_symbol, itertools.count())) + new_symbols = {} + for k in inputs: + symbol = next(get_symbol) + while symbol in symbols: symbol = next(get_symbol) - while symbol in symbols: - symbol = next(get_symbol) - symbols.add(symbol) - new_symbols[k] = symbol - - # Manually broadcast using einsum symbols. - assert "." not in equation - ins, out = equation.split("->") - ins = ins.split(",") - ins = [ - "".join(new_symbols[k] for k in x.inputs) + x_out - for x, x_out in zip(operands, ins) - ] - out = "".join(new_symbols[k] for k in inputs) + out - equation = ",".join(ins) + "->" + out + symbols.add(symbol) + new_symbols[k] = symbol - data = ops.einsum(equation, *[x.data for x in operands]) - return Tensor(data, inputs) + # Manually broadcast using einsum symbols. + assert "." not in equation + ins, out = equation.split("->") + ins = ins.split(",") + ins = [ + "".join(new_symbols[k] for k in x.inputs) + x_out + for x, x_out in zip(operands, ins) + ] + out = "".join(new_symbols[k] for k in inputs) + out + equation = ",".join(ins) + "->" + out - return None # defer to default implementation + data = ops.einsum(equation, *[x.data for x in operands]) + return Tensor(data, inputs) def tensordot(x, y, dims): @@ -1129,7 +1129,7 @@ def tensordot(x, y, dims): symbols[y_start:y_end], symbols[x_start:y_start] + symbols[x_end:y_end], ) - return Einsum(equation, (x, y)) + return Einsum(equation, x, y) def stack(parts, dim=0): diff --git a/funsor/terms.py b/funsor/terms.py index 6e3aca29c..cc872e280 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1289,6 +1289,20 @@ def eager_binary_align_align(op, lhs, rhs): return Binary(op, lhs.arg, rhs.arg) +class Finitary(Funsor): + def __init__(self, op, args): + assert isinstance(op, ops.Op) + assert isinstance(args, tuple) + assert all(isinstance(v, Funsor) for v in args) + inputs = OrderedDict() + for arg in args: + inputs.update(arg.inputs) + output = find_domain(op, *(arg.output for arg in args)) + super().__init__(inputs, output) + self.op = op + self.args = args + + class Stack(Funsor): """ Stack of funsors along a new input dimension. diff --git a/test/test_affine.py b/test/test_affine.py index a991ea3c8..70e811099 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -9,7 +9,7 @@ from funsor.cnf import Contraction from funsor.domains import Bint, Real, Reals # noqa: F401 from funsor.tensor import Einsum, Tensor -from funsor.terms import Number, Unary, Variable +from funsor.terms import Finitary, Number, Unary, Variable from funsor.testing import ( assert_close, check_funsor, @@ -107,9 +107,9 @@ def test_affine_subs(expr, expected_type, expected_inputs): "Variable('x', Reals[2]) * randn(2) + ones(2)", "Variable('x', Reals[2]) + Tensor(randn(3, 2), OrderedDict(i=Bint[3]))", "Einsum('abcd,ac->bd'," - " (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))", + " Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))", "Tensor(randn(3, 5)) + Einsum('abcd,ac->bd'," - " (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))", + " Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))", "Variable('x', Reals[2, 8])[0] + randn(8)", "Variable('x', Reals[2, 8])[Variable('i', Bint[2])] / 4 - 3.5", ], @@ -117,7 +117,7 @@ def test_affine_subs(expr, expected_type, expected_inputs): def test_extract_affine(expr): x = eval(expr) assert is_affine(x) - assert isinstance(x, (Unary, Contraction, Einsum)) + assert isinstance(x, (Unary, Contraction, Finitary)) real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() if d.dtype == "real") const, coeffs = extract_affine(x) @@ -134,7 +134,7 @@ def test_extract_affine(expr): assert isinstance(expected, Tensor) actual = const + sum( - Einsum(eqn, (coeff, subs[k])) for k, (coeff, eqn) in coeffs.items() + Einsum(eqn, coeff, subs[k]) for k, (coeff, eqn) in coeffs.items() ) assert isinstance(actual, Tensor) assert_close(actual, expected) @@ -157,7 +157,7 @@ def test_extract_affine(expr): "Variable('x', Reals[2,3]) @ Variable('y', Reals[3,4])", "random_gaussian(OrderedDict(x=Real))", "Einsum('abcd,ac->bd'," - " (Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4])))", + " Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4]))", ], ) def test_not_is_affine(expr): diff --git a/test/test_gaussian.py b/test/test_gaussian.py index d9e66af02..264953907 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -368,7 +368,7 @@ def test_eager_subs_variable(): ( ( "y", - 'Einsum("abc,bc->a", (Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5])))', + 'Einsum("abc,bc->a", Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5]))', ), ), ], diff --git a/test/test_tensor.py b/test/test_tensor.py index 2050fb105..5a6e62f1e 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -950,7 +950,7 @@ def test_einsum(equation): tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs] funsors = [Tensor(x) for x in tensors] expected = Tensor(ops.einsum(equation, *tensors)) - actual = Einsum(equation, tuple(funsors)) + actual = Einsum(equation, *funsors) assert_close(actual, expected, atol=1e-5, rtol=None) @@ -968,7 +968,7 @@ def test_batched_einsum(equation, batch1, batch2): random_tensor(batch, Reals[tuple(sizes[d] for d in dims)]) for batch, dims in zip([batch1, batch2], inputs) ] - actual = Einsum(equation, tuple(funsors)) + actual = Einsum(equation, *funsors) _equation = ",".join("..." + i for i in inputs) + "->..." + output inputs, tensors = align_tensors(*funsors)