From 32bc76b4a2b6b457bfacbd00660adc232eaa6ed2 Mon Sep 17 00:00:00 2001 From: Gruschwick <40718240+Gruschwick@users.noreply.github.com> Date: Sun, 20 Jan 2019 20:04:08 +0100 Subject: [PATCH] Added special_math_ops.py file. --- tensorflow/python/ops/special_math_ops.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 21f4996798eda2..10b18a994439e9 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -236,12 +236,6 @@ def einsum(equation, *inputs, **kwargs): output_axis_labels = ''.join( sorted(ax for ax in indices if counts[ax] == 1)) - for a in axis_labels: - for input_labels in input_axis_labels: - if input_labels.count(a) > 1: - raise ValueError( - 'Subscript not supported: an axis appears more than once: %s' % - input_labels) for a in axis_labels: input_count = sum(1 for s in input_axis_labels if a in s) @@ -267,7 +261,12 @@ def einsum(equation, *inputs, **kwargs): i for i, a in enumerate(temp_axis_labels) if a not in output_axis_labels ] - temp = math_ops.reduce_sum(temp, axis=axis) + for a in axis_labels: + for input_labels in input_axis_labels: + if input_labels.count(a) == 2: + temp = math_ops.trace(temp) + else: + temp = math_ops.reduce_sum(temp, axis=axis) temp_axis_labels = ''.join( a for a in temp_axis_labels if a in output_axis_labels)