Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import opt_einsum

from funsor.domains import Bint
from funsor.interpreter import gensym
from funsor.tensor import Einsum, Tensor, get_default_prototype
from funsor.terms import Binary, Bint, Funsor, Lambda, Reduce, Unary, Variable
from funsor.tensor import EinsumOp, Tensor, get_default_prototype
from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable

from . import ops

Expand Down Expand Up @@ -91,12 +92,12 @@ def _(fn):
return affine_inputs(fn.arg) - fn.reduced_vars


@affine_inputs.register(Einsum)
@affine_inputs.register(Finitary[EinsumOp, tuple])
def _(fn):
# This is simply a multiary version of the above Binary(ops.mul, ...) case.
results = []
for i, x in enumerate(fn.operands):
others = fn.operands[:i] + fn.operands[i + 1 :]
for i, x in enumerate(fn.args):
others = fn.args[:i] + fn.args[i + 1 :]
other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset())
results.append(affine_inputs(x) - other_inputs)
# This multilinear case introduces incompleteness, since some vars
Expand All @@ -114,7 +115,7 @@ def extract_affine(fn):

x = ...
const, coeffs = extract_affine(x)
y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output)))
y = sum(Einsum(eqn, coeff, Variable(var, coeff.output))
for var, (coeff, eqn) in coeffs.items())
assert_close(y, x)
assert frozenset(coeffs) == affine_inputs(x)
Expand Down
122 changes: 61 additions & 61 deletions funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .ops import GetitemOp, MatmulOp, Op, ReshapeOp
from .terms import (
Binary,
Finitary,
Funsor,
FunsorMeta,
Lambda,
Expand Down Expand Up @@ -770,6 +771,13 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
return Tensor(data, inputs, lhs.dtype)


@eager.register(Finitary, Op, typing.Tuple[Tensor, ...])
def eager_finitary_generic_tensors(op, args):
inputs, raw_args = align_tensors(*args)
raw_result = op(*raw_args)
return Tensor(raw_result, inputs, args[0].dtype)


@eager.register(Lambda, Variable, Tensor)
def eager_lambda(var, expr):
inputs = expr.inputs.copy()
Expand Down Expand Up @@ -1019,7 +1027,30 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
return functools.partial(_function, inputs, output)


class Einsum(Funsor):
class EinsumOp(ops.Op, metaclass=ops.CachedOpMeta):
def __init__(self, equation):
self.equation = equation


@find_domain.register(EinsumOp)
def _find_domain_einsum(op, *operands):
equation = op.equation
ein_inputs, ein_output = equation.split("->")
ein_inputs = ein_inputs.split(",")
size_dict = {}
for ein_input, x in zip(ein_inputs, operands):
assert x.dtype == "real"
assert len(ein_input) == len(x.shape)
for name, size in zip(ein_input, x.shape):
other_size = size_dict.setdefault(name, size)
if other_size != size:
raise ValueError(
"Size mismatch at {}: {} vs {}".format(name, size, other_size)
)
return Reals[tuple(size_dict[d] for d in ein_output)]


def Einsum(equation, *operands):
"""
Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors.

Expand All @@ -1030,70 +1061,39 @@ class Einsum(Funsor):
:param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation.
:param tuple operands: A tuple of input funsors.
"""

def __init__(self, equation, operands):
assert isinstance(equation, str)
assert isinstance(operands, tuple)
assert all(isinstance(x, Funsor) for x in operands)
ein_inputs, ein_output = equation.split("->")
ein_inputs = ein_inputs.split(",")
size_dict = {}
inputs = OrderedDict()
assert len(ein_inputs) == len(operands)
for ein_input, x in zip(ein_inputs, operands):
assert x.dtype == "real"
inputs.update(x.inputs)
assert len(ein_input) == len(x.output.shape)
for name, size in zip(ein_input, x.output.shape):
other_size = size_dict.setdefault(name, size)
if other_size != size:
raise ValueError(
"Size mismatch at {}: {} vs {}".format(name, size, other_size)
)
output = Reals[tuple(size_dict[d] for d in ein_output)]
super(Einsum, self).__init__(inputs, output)
self.equation = equation
self.operands = operands

def __repr__(self):
return "Einsum({}, {})".format(repr(self.equation), repr(self.operands))

def __str__(self):
return "Einsum({}, {})".format(repr(self.equation), str(self.operands))
return Finitary(EinsumOp(equation), tuple(operands))


@eager.register(Einsum, str, tuple)
def eager_einsum(equation, operands):
if all(isinstance(x, Tensor) for x in operands):
# Make new symbols for inputs of operands.
inputs = OrderedDict()
for x in operands:
inputs.update(x.inputs)
symbols = set(equation)
get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
new_symbols = {}
for k in inputs:
@eager.register(Finitary, EinsumOp, typing.Tuple[Tensor, ...])
def eager_einsum(op, operands):
# Make new symbols for inputs of operands.
equation = op.equation
inputs = OrderedDict()
for x in operands:
inputs.update(x.inputs)
symbols = set(equation)
get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
new_symbols = {}
for k in inputs:
symbol = next(get_symbol)
while symbol in symbols:
symbol = next(get_symbol)
while symbol in symbols:
symbol = next(get_symbol)
symbols.add(symbol)
new_symbols[k] = symbol

# Manually broadcast using einsum symbols.
assert "." not in equation
ins, out = equation.split("->")
ins = ins.split(",")
ins = [
"".join(new_symbols[k] for k in x.inputs) + x_out
for x, x_out in zip(operands, ins)
]
out = "".join(new_symbols[k] for k in inputs) + out
equation = ",".join(ins) + "->" + out
symbols.add(symbol)
new_symbols[k] = symbol

data = ops.einsum(equation, *[x.data for x in operands])
return Tensor(data, inputs)
# Manually broadcast using einsum symbols.
assert "." not in equation
ins, out = equation.split("->")
ins = ins.split(",")
ins = [
"".join(new_symbols[k] for k in x.inputs) + x_out
for x, x_out in zip(operands, ins)
]
out = "".join(new_symbols[k] for k in inputs) + out
equation = ",".join(ins) + "->" + out

return None # defer to default implementation
data = ops.einsum(equation, *[x.data for x in operands])
return Tensor(data, inputs)


def tensordot(x, y, dims):
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def tensordot(x, y, dims):
symbols[y_start:y_end],
symbols[x_start:y_start] + symbols[x_end:y_end],
)
return Einsum(equation, (x, y))
return Einsum(equation, x, y)


def stack(parts, dim=0):
Expand Down
14 changes: 14 additions & 0 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,20 @@ def eager_binary_align_align(op, lhs, rhs):
return Binary(op, lhs.arg, rhs.arg)


class Finitary(Funsor):
def __init__(self, op, args):
assert isinstance(op, ops.Op)
assert isinstance(args, tuple)
assert all(isinstance(v, Funsor) for v in args)
inputs = OrderedDict()
for arg in args:
inputs.update(arg.inputs)
output = find_domain(op, *(arg.output for arg in args))
super().__init__(inputs, output)
self.op = op
self.args = args


class Stack(Funsor):
"""
Stack of funsors along a new input dimension.
Expand Down
12 changes: 6 additions & 6 deletions test/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from funsor.cnf import Contraction
from funsor.domains import Bint, Real, Reals # noqa: F401
from funsor.tensor import Einsum, Tensor
from funsor.terms import Number, Unary, Variable
from funsor.terms import Finitary, Number, Unary, Variable
from funsor.testing import (
assert_close,
check_funsor,
Expand Down Expand Up @@ -107,17 +107,17 @@ def test_affine_subs(expr, expected_type, expected_inputs):
"Variable('x', Reals[2]) * randn(2) + ones(2)",
"Variable('x', Reals[2]) + Tensor(randn(3, 2), OrderedDict(i=Bint[3]))",
"Einsum('abcd,ac->bd',"
" (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))",
" Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))",
"Tensor(randn(3, 5)) + Einsum('abcd,ac->bd',"
" (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))",
" Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))",
"Variable('x', Reals[2, 8])[0] + randn(8)",
"Variable('x', Reals[2, 8])[Variable('i', Bint[2])] / 4 - 3.5",
],
)
def test_extract_affine(expr):
x = eval(expr)
assert is_affine(x)
assert isinstance(x, (Unary, Contraction, Einsum))
assert isinstance(x, (Unary, Contraction, Finitary))
real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() if d.dtype == "real")

const, coeffs = extract_affine(x)
Expand All @@ -134,7 +134,7 @@ def test_extract_affine(expr):
assert isinstance(expected, Tensor)

actual = const + sum(
Einsum(eqn, (coeff, subs[k])) for k, (coeff, eqn) in coeffs.items()
Einsum(eqn, coeff, subs[k]) for k, (coeff, eqn) in coeffs.items()
)
assert isinstance(actual, Tensor)
assert_close(actual, expected)
Expand All @@ -157,7 +157,7 @@ def test_extract_affine(expr):
"Variable('x', Reals[2,3]) @ Variable('y', Reals[3,4])",
"random_gaussian(OrderedDict(x=Real))",
"Einsum('abcd,ac->bd',"
" (Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4])))",
" Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4]))",
],
)
def test_not_is_affine(expr):
Expand Down
2 changes: 1 addition & 1 deletion test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def test_eager_subs_variable():
(
(
"y",
'Einsum("abc,bc->a", (Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5])))',
'Einsum("abc,bc->a", Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5]))',
),
),
],
Expand Down
4 changes: 2 additions & 2 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def test_einsum(equation):
tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs]
funsors = [Tensor(x) for x in tensors]
expected = Tensor(ops.einsum(equation, *tensors))
actual = Einsum(equation, tuple(funsors))
actual = Einsum(equation, *funsors)
assert_close(actual, expected, atol=1e-5, rtol=None)


Expand All @@ -968,7 +968,7 @@ def test_batched_einsum(equation, batch1, batch2):
random_tensor(batch, Reals[tuple(sizes[d] for d in dims)])
for batch, dims in zip([batch1, batch2], inputs)
]
actual = Einsum(equation, tuple(funsors))
actual = Einsum(equation, *funsors)

_equation = ",".join("..." + i for i in inputs) + "->..." + output
inputs, tensors = align_tensors(*funsors)
Expand Down