# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
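"""Tests for the tf_helpers.lax ops, compared against jax.lax."""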
import itertools
import sys

from absl.testing import parameterized
import jax
from jax import numpy as jnp
import numpy as onp
import tensorflow as tf
from tensorflow.python.ops import numpy_ops as tfnp
from tensorflow.python.platform import test
from tf_helpers import lax


class TFLaxTest(tf.test.TestCase, parameterized.TestCase):
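  """Compares the tf_helpers.lax ops against their jax.lax counterparts."""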

  @parameterized.parameters(
      {"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((3, 2)),
       "dims": (((1,), (0,)), ((), ()))},
      {"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((5, 3)),
       "dims": (((0, 1), (0, 1)), ((), ()))},
      {"lhs_np": onp.ones((5, 3, 2)), "rhs_np": onp.ones((2, 3, 2)),
       "dims": (((1, 2), (1, 0)), ((), ()))},
      {"lhs_np": onp.ones((6, 5, 3)), "rhs_np": onp.ones((6, 3, 2)),
       "dims": (((2,), (1,)), ((0,), (0,)))},
      {"lhs_np": onp.ones((6, 3, 5)), "rhs_np": onp.ones((6, 3, 2)),
       "dims": (((1,), (1,)), ((0,), (0,)))},
      {"lhs_np": onp.ones((5, 3, 2, 2)), "rhs_np": onp.ones((5, 2, 2, 6)),
       "dims": (((2, 3), (1, 2)), ((0,), (0,)))},
      {"lhs_np": onp.ones((2, 2, 5, 3)), "rhs_np": onp.ones((2, 2, 3, 2)),
       "dims": (((3,), (2,)), ((0, 1), (0, 1)))},
      {"lhs_np": onp.ones((2, 2, 5, 2)), "rhs_np": onp.ones((2, 2, 3, 2)),
       "dims": (((3,), (1,)), ((0,), (0,)))},
      {"lhs_np": onp.ones((2, 2, 5, 3, 3)), "rhs_np": onp.ones((2, 3, 2, 3, 2)),
       "dims": (((4,), (1,)), ((0,), (0,)))},
  )
  def test_tf_dot_general(self, lhs_np, rhs_np, dims):
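    """Checks lax.dot_general against jax.lax.dot_general."""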
    ans = jax.lax.dot_general(lhs_np, rhs_np, dims)
    result = lax.dot_general(lhs_np, rhs_np, dims)
    self.assertAllClose(result, tfnp.array(ans))

  @parameterized.named_parameters([
      ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       "_lhs_dilation={}_rhs_dilation={}"
       "_feature_group_count={}_batch_group_count={}_dims={}"
       "_perms={}".format(lhs_shape, rhs_shape, strides, padding,
                          lhs_dilation, rhs_dilation, feature_group_count,
                          batch_group_count, ",".join(dimension_numbers),
                          perms),
       lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
       feature_group_count, batch_group_count, dimension_numbers, perms)
      for batch_group_count, feature_group_count in [(1, 1)]
      for lhs_shape, rhs_shape in [
          ((b * batch_group_count, i * feature_group_count, 9, w),
           (j * feature_group_count * batch_group_count, i, 4, 5))
          for w in [0, 10]
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for strides in [(1, 1), (2, 1)]
      for padding in ["SAME"]
      for lhs_dilation, rhs_dilation in [
          (None, (1, 1))
      ]
      for dimension_numbers, perms in [
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
      ]])
  def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
                             padding, lhs_dilation, rhs_dilation,
                             feature_group_count, batch_group_count,
                             dimension_numbers, perms):
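    """Checks conv_general_dilated against its jax.lax counterpart."""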
    tf.print("dimension_numbers: {}".format(dimension_numbers),
             output_stream=sys.stdout)
    lhs_perm, rhs_perm = perms  # permute to compatible shapes

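    # Build the same all-ones inputs for both backends, transposed into the
    # layouts named by dimension_numbers.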
    lhs_tf = tfnp.transpose(tfnp.ones(lhs_shape), lhs_perm)
    rhs_tf = tfnp.transpose(tfnp.ones(rhs_shape), rhs_perm)

    lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm)
    rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm)

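    # Reference result computed with JAX.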
    jax_conv = jax.lax.conv_general_dilated(
        lhs_jax, rhs_jax, strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count)

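    # The TF-backed helper is additionally passed the shape of the JAX result.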
    tf_conv = lax.conv_general_dilated(
        lhs_tf, rhs_tf, strides, padding, jax_conv.shape, lhs_dilation,
        rhs_dilation, dimension_numbers, feature_group_count,
        batch_group_count)

    self.assertAllEqual(tf_conv, tfnp.asarray(jax_conv))


if __name__ == "__main__":
  test.main()