diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 1e837c17f..632fa03ea 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1,5 +1,7 @@ package com.ibm.wala.cast.python.ml.test; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.INT32; import static com.ibm.wala.cast.python.ml.types.TensorType.mnistInput; import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints; import static java.util.Arrays.asList; @@ -56,8 +58,61 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape { private static final Logger LOGGER = Logger.getLogger(TestTensorflow2Model.class.getName()); + private static final String FLOAT_32 = FLOAT32.name().toLowerCase(); + + private static final String INT_32 = INT32.name().toLowerCase(); + private static final TensorType MNIST_INPUT = mnistInput(); + private static final TensorType SCALAR_TENSOR_OF_INT32 = new TensorType(INT_32, emptyList()); + + private static final TensorType SCALAR_TENSOR_OF_FLOAT32 = new TensorType(FLOAT_32, emptyList()); + + private static final TensorType TENSOR_1_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(1), new NumericDim(2))); + + private static final TensorType TENSOR_2_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(2))); + + private static final TensorType TENSOR_3_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2))); + + private static final TensorType TENSOR_2_1_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1))); + + private static final TensorType 
TENSOR_2_3_3_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3))); + + private static final TensorType TENSOR_2_3_3_INT32 = + new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3))); + + private static final TensorType TENSOR_2_3_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4))); + + private static final TensorType TENSOR_2_3_4_INT32 = + new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4))); + + private static final TensorType TENSOR_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2))); + + private static final TensorType TENSOR_2_INT32 = + new TensorType(INT_32, asList(new NumericDim(2))); + + private static final TensorType TENSOR_3_INT32 = + new TensorType(INT_32, asList(new NumericDim(3))); + + private static final TensorType TENSOR_3_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(3))); + + private static final TensorType TENSOR_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(4))); + + private static final TensorType TENSOR_5_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(5))); + + private static final TensorType TENSOR_5_INT32 = + new TensorType(INT_32, asList(new NumericDim(5))); + @Test public void testValueIndex() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { @@ -126,70 +181,156 @@ public void testFunction4() test("tf2_test_function4.py", "func2", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); } + @Test + public void testFunction5() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function5.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); + } + + @Test + public void testFunction6() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function6.py", "func", 1, 1, Map.of(2, 
Set.of(TENSOR_2_1_FLOAT32))); + } + + @Test + public void testFunction7() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function7.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_FLOAT32))); + } + + @Test + public void testFunction8() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_function8.py", + "func", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_2_FLOAT32))); + } + + @Test + public void testFunction9() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function9.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + } + + @Test + public void testFunction10() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function10.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_4_INT32))); + } + + @Test + public void testFunction11() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function11.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_3_INT32))); + } + + /** Test https://github.com/wala/ML/issues/308. */ + @Test + public void testFunction12() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_function12.py", + "func", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_3_2_FLOAT32))); + } + + /** + * Test https://github.com/wala/ML/issues/308. + * + *
<p>
This one has lexical scoping. + */ + @Test + public void testFunction13() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_function13.py", + "func", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_3_2_FLOAT32))); + } + @Test public void testDecorator() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator2.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator2.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator3() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator3.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator3.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_2_FLOAT32))); } @Test public void testDecorator4() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator4.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator4.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator5() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator5.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator5.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator6() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator6.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + 
test("tf2_test_decorator6.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator7() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator7.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator7.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator8() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator8.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator8.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator9() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator9.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator9.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator10() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator10.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator10.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator11() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator11.py", "C.returned", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_decorator11.py", "C.returned", 1, 1, Map.of(3, Set.of(TENSOR_5_INT32))); + } + + @Test + public void testDecorator12() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_decorator12.py", + "returned", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_FLOAT32, TENSOR_2_INT32))); } @Test @@ -582,7 +723,11 @@ public void testTensorList() "add", 2, 2, - Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + Map.of( + 2, + 
Set.of(TENSOR_1_2_FLOAT32, TENSOR_2_2_FLOAT32), + 3, + Set.of(TENSOR_1_2_FLOAT32, TENSOR_2_2_FLOAT32))); } @Test @@ -717,13 +862,13 @@ public void testModelAttributes6() @Test public void testCallbacks() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_callbacks.py", "replica_fn", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_callbacks.py", "replica_fn", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_FLOAT32))); } @Test public void testCallbacks2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_callbacks2.py", "replica_fn", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_callbacks2.py", "replica_fn", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_FLOAT32))); } @Test @@ -822,13 +967,13 @@ public void testAutoencoder4() @Test public void testSigmoid() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_sigmoid.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_sigmoid.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_4_FLOAT32))); } @Test public void testSigmoid2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_sigmoid2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_sigmoid2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_4_FLOAT32))); } @Test @@ -870,19 +1015,34 @@ public void testAdd6() @Test public void testAdd7() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add7.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add7.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd8() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add8.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, 
Set.of(MNIST_INPUT))); + test( + "tf2_test_add8.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd9() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add9.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add9.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -966,19 +1126,34 @@ public void testAdd22() @Test public void testAdd23() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add23.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add23.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test public void testAdd24() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add24.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add24.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test public void testAdd25() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add25.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add25.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test @@ -1126,7 +1301,12 @@ public void testAdd47() @Test public void testAdd48() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add48.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add48.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1134,7 +1314,12 @@ public void testAdd49() throws 
ClassHierarchyException, IllegalArgumentException, CancelException, IOException { // NOTE: Set the expected number of tensor variables to 3 once // https://github.com/wala/ML/issues/135 is fixed. - test("tf2_test_add49.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add49.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1533,6 +1718,34 @@ public void testAdd115() test("tf2_test_add115.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); } + @Test + public void testAdd116() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_add116.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); + } + + @Test + public void testAdd117() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_add117.py", + "add", + 2, + 2, + Map.of( + 2, + Set.of( + TENSOR_1_2_FLOAT32, + new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2)))), + 3, + Set.of(TENSOR_2_2_FLOAT32))); + } + @Test public void testMultiGPUTraining() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { @@ -1557,19 +1770,19 @@ public void testMultiGPUTraining2() @Test public void testReduceMean() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_reduce_mean.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_reduce_mean.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testReduceMean2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_reduce_mean.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_reduce_mean.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testReduceMean3() throws 
ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_reduce_mean.py", "h", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_reduce_mean.py", "h", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1604,13 +1817,13 @@ public void testSparseSoftmaxCrossEntropyWithLogits() "f", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_3_INT32))); } @Test public void testRelu() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_relu.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_relu.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_FLOAT32))); } @Test @@ -1622,7 +1835,7 @@ public void testTFRange() @Test public void testTFRange2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_tf_range2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_tf_range2.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -1634,41 +1847,41 @@ public void testTFRange3() @Test public void testImport() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport3() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import3.py", "f", 1, 2, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_import3.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import3.py", "f", 1, 2, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + 
test("tf2_test_import3.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport4() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import4.py", "f", 1, 2, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_import4.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import4.py", "f", 1, 2, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + test("tf2_test_import4.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport5() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test("tf2_test_import5.py", "f", 0, 1); - test("tf2_test_import5.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import5.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport6() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test("tf2_test_import6.py", "f", 0, 1); - test("tf2_test_import6.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import6.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1696,8 +1909,8 @@ public void testImport8() @Test public void testImport9() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import9.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_import9.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import9.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + test("tf2_test_import9.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1710,7 +1923,7 @@ public void testModule() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** This test needs a PYTHONPATH that points to `proj`. 
*/ @@ -1726,7 +1939,7 @@ public void testModule2() "proj", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1742,7 +1955,7 @@ public void testModule3() "proj2", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1764,7 +1977,7 @@ public void testModule4() "proj3", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -1778,7 +1991,7 @@ public void testModule4() "proj3", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1791,7 +2004,7 @@ public void testModule5() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test needs a PYTHONPATH that points to `proj4`. */ @@ -1807,7 +2020,7 @@ public void testModule6() "proj4", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1823,7 +2036,7 @@ public void testModule7() "proj5", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1845,7 +2058,7 @@ public void testModule8() "proj6", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -1859,7 +2072,7 @@ public void testModule8() "proj6", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1872,7 +2085,7 @@ public void testModule9() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1885,7 +2098,7 @@ public void testModule10() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test needs a PYTHONPATH that points to `proj7`. 
*/ @@ -1904,7 +2117,7 @@ public void testModule11() "proj7", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1923,7 +2136,7 @@ public void testModule12() "proj8", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1942,7 +2155,7 @@ public void testModule13() "proj9", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1960,7 +2173,7 @@ public void testModule14() "proj10", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1978,7 +2191,7 @@ public void testModule15() "proj11", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1992,7 +2205,7 @@ public void testModule16() "proj12", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2013,7 +2226,7 @@ public void testModule17() "proj13", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2039,7 +2252,7 @@ public void testModule18() "proj14", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -2054,7 +2267,7 @@ public void testModule18() "proj14", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2074,7 +2287,7 @@ public void testModule19() "proj15", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2092,7 +2305,7 @@ public void testModule20() "proj16", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2112,7 +2325,7 @@ public void testModule21() "proj17", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2130,7 +2343,7 @@ public void testModule22() "proj18", 1, 1, - Map.of(2, 
Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2154,7 +2367,7 @@ public void testModule23() "proj19", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2172,7 +2385,7 @@ public void testModule24() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2190,7 +2403,7 @@ public void testModule25() "proj20", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2208,7 +2421,7 @@ public void testModule26() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2232,7 +2445,7 @@ public void testModule27() "proj21", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -2247,7 +2460,7 @@ public void testModule27() "proj21", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2265,7 +2478,7 @@ public void testModule28() "proj22", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2283,7 +2496,7 @@ public void testModule29() "proj23", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2301,7 +2514,7 @@ public void testModule30() "proj24", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2321,7 +2534,7 @@ public void testModule31() "proj25", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2339,7 +2552,7 @@ public void testModule32() "proj26", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2359,7 +2572,7 @@ public void testModule33() "proj27", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2379,7 +2592,7 @@ public void testModule34() "proj28", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** 
@@ -2397,7 +2610,7 @@ public void testModule35() "proj29", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2415,7 +2628,7 @@ public void testModule36() "proj30", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2433,7 +2646,7 @@ public void testModule37() "proj31", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2451,7 +2664,7 @@ public void testModule38() "proj32", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2469,7 +2682,7 @@ public void testModule39() "proj33", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2487,7 +2700,7 @@ public void testModule40() "proj34", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2512,7 +2725,7 @@ public void testModule41() "proj35", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2537,7 +2750,7 @@ public void testModule42() "proj36", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2562,7 +2775,7 @@ public void testModule43() "proj37", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2587,7 +2800,7 @@ public void testModule44() "proj38", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2605,7 +2818,7 @@ public void testModule45() "proj39", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2623,7 +2836,7 @@ public void testModule46() "proj40", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2641,7 +2854,7 @@ public void testModule47() "proj41", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2659,7 +2872,7 @@ public void testModule48() "proj42", 1, 1, - 
Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2684,7 +2897,7 @@ public void testModule49() "proj43", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2709,7 +2922,7 @@ public void testModule50() "proj44", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2734,7 +2947,7 @@ public void testModule51() "proj45", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2759,7 +2972,7 @@ public void testModule52() "proj46", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2786,7 +2999,7 @@ public void testModule53() "proj47", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -2804,7 +3017,7 @@ public void testModule53() "proj47", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2818,7 +3031,7 @@ public void testModule54() "proj51", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2832,7 +3045,7 @@ public void testModule55() "proj52", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2846,7 +3059,7 @@ public void testModule56() "proj53", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2860,7 +3073,7 @@ public void testModule57() "proj54", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. 
*/ @@ -2874,7 +3087,7 @@ public void testModule58() "proj55", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2888,7 +3101,7 @@ public void testModule59() "proj51", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2902,7 +3115,7 @@ public void testModule60() "proj52", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2916,7 +3129,7 @@ public void testModule61() "proj56", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2930,7 +3143,7 @@ public void testModule62() "proj57", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2944,7 +3157,7 @@ public void testModule63() "proj58", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2958,7 +3171,7 @@ public void testModule64() "proj59", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2972,7 +3185,7 @@ public void testModule65() "proj60", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2986,7 +3199,7 @@ public void testModule66() "proj61", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -3000,7 +3213,7 @@ public void testModule67() "proj62", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/205. 
*/ @@ -3014,7 +3227,7 @@ public void testModule68() "proj63", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/205. */ @@ -3028,7 +3241,7 @@ public void testModule69() "proj64", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3042,7 +3255,7 @@ public void testModule70() "proj65", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3056,7 +3269,7 @@ public void testModule71() "proj67", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3070,7 +3283,7 @@ public void testModule72() "proj68", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3084,7 +3297,7 @@ public void testModule73() "proj69", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3098,7 +3311,7 @@ public void testModule74() "proj70", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. */ @@ -3112,7 +3325,7 @@ public void testModule75() "proj71", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. */ @@ -3126,7 +3339,7 @@ public void testModule76() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. */ @@ -3140,7 +3353,7 @@ public void testModule77() "proj72", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. 
*/ @@ -3154,7 +3367,7 @@ public void testModule78() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/209. */ @@ -3174,7 +3387,7 @@ public void testModule79() "proj73", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); test( new String[] { @@ -3189,7 +3402,7 @@ public void testModule79() "proj73", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test https://github.com/wala/ML/issues/209. */ @@ -3209,7 +3422,7 @@ public void testModule80() "proj74", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); test( new String[] { @@ -3224,7 +3437,7 @@ public void testModule80() "proj74", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3234,7 +3447,7 @@ public void testStaticMethod() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3244,7 +3457,7 @@ public void testStaticMethod2() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3254,7 +3467,7 @@ public void testStaticMethod3() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3264,7 +3477,7 @@ public void testStaticMethod4() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3274,7 +3487,7 @@ public void testStaticMethod5() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, 
Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3284,7 +3497,7 @@ public void testStaticMethod6() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3294,7 +3507,7 @@ public void testStaticMethod7() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3304,7 +3517,7 @@ public void testStaticMethod8() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3314,7 +3527,7 @@ public void testStaticMethod9() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 2, 2, - Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32), 3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3324,17 +3537,39 @@ public void testStaticMethod10() throws ClassHierarchyException, CancelException "MyClass.the_static_method", 2, 2, - Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32), 3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testStaticMethod11() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_static_method11.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_static_method11.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testStaticMethod12() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_static_method12.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_static_method12.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + } + + @Test(expected = IllegalStateException.class) + public void testStaticMethod13() throws ClassHierarchyException, CancelException, IOException { 
+ // NOTE: This test will no longer throw an exception once data types other than lists are + // supported for shape arguments. + test( + "tf2_test_static_method13.py", + "MyClass.the_static_method", + 1, + 1, + Map.of(2, Set.of(TENSOR_5_FLOAT32))); + } + + @Test + public void testStaticMethod14() throws ClassHierarchyException, CancelException, IOException { + test( + "tf2_test_static_method14.py", + "MyClass.the_static_method", + 1, + 1, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -3344,7 +3579,7 @@ public void testClassMethod() throws ClassHierarchyException, CancelException, I "MyClass.the_class_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3354,44 +3589,44 @@ public void testClassMethod2() throws ClassHierarchyException, CancelException, "MyClass.the_class_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testClassMethod3() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_class_method3.py", "MyClass.f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_class_method3.py", "MyClass.f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testClassMethod4() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_class_method4.py", "MyClass.f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_class_method4.py", "MyClass.f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testClassMethod5() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_class_method5.py", "MyClass.f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_class_method5.py", "MyClass.f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testAbstractMethod() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_abstract_method.py", "D.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); - 
test("tf2_test_abstract_method.py", "C.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_abstract_method.py", "D.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); + test("tf2_test_abstract_method.py", "C.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testAbstractMethod2() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_abstract_method2.py", "D.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); - test("tf2_test_abstract_method2.py", "C.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_abstract_method2.py", "D.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); + test("tf2_test_abstract_method2.py", "C.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testAbstractMethod3() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_abstract_method3.py", "C.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_abstract_method3.py", "C.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test https://github.com/wala/ML/issues/188. 
*/ @@ -3408,27 +3643,27 @@ public void testDecoratedMethod3() throws ClassHierarchyException, CancelExcepti @Test public void testDecoratedMethod4() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method4.py", "raffi", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method4.py", "raffi", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod5() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method5.py", "raffi", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method5.py", "raffi", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod6() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method6.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method6.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod7() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method7.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method7.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod8() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method8.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method8.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** @@ -3450,7 +3685,7 @@ public void testDecoratedMethod10() throws ClassHierarchyException, CancelExcept @Test public void testDecoratedMethod11() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method11.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method11.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3468,12 +3703,42 @@ public void 
testDecoratedMethod13() throws ClassHierarchyException, CancelExcept @Test public void testDecoratedFunctions() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_functions.py", "dummy_fun", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "dummy_test", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function2", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function3", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function4", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test( + "tf2_test_decorated_functions.py", + "dummy_fun", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "dummy_test", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function2", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function3", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function4", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test a pytest with decorators. */ @@ -3504,21 +3769,21 @@ public void testDecoratedFunctions3() "proj48", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test a pytest without decorators. This is a "control." 
*/ @Test public void testDecoratedFunctions4() throws ClassHierarchyException, CancelException, IOException { - test("test_decorated_functions2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("test_decorated_functions2.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test a pytest with a decorator. */ @Test public void testDecoratedFunctions5() throws ClassHierarchyException, CancelException, IOException { - test("test_decorated_functions3.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("test_decorated_functions3.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** @@ -3541,14 +3806,14 @@ public void testDecoratedFunctions6() "proj49", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test a Pytest with a decorator without parameters. */ @Test public void testDecoratedFunctions7() throws ClassHierarchyException, CancelException, IOException { - test("test_decorated_functions4.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("test_decorated_functions4.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** @@ -3571,7 +3836,7 @@ public void testDecoratedFunctions8() "proj50", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -3580,7 +3845,7 @@ public void testDecoratedFunctions8() @Test public void testDecoratedFunctions9() throws ClassHierarchyException, CancelException, IOException { - test("decorated_function_test.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("decorated_function_test.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test https://github.com/wala/ML/issues/195. 
*/ diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index da59452d1..36e3702f6 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -154,6 +154,9 @@ + + + @@ -218,6 +221,10 @@ + + + + @@ -315,6 +322,17 @@ + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java index ee2a41b90..bc748c1b1 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java @@ -447,11 +447,11 @@ public String toString() { }; } - private final Map init; + private final Map> init; public TensorTypeAnalysis( Graph G, - Map init, + Map> init, Map reshapeTypes, Map set_shapes, Set conv2ds, @@ -480,7 +480,8 @@ protected TensorVariable[] makeStmtRHS(int size) { protected void initializeVariables() { super.initializeVariables(); for (PointsToSetVariable src : init.keySet()) { - getOut(src).state.add(init.get(src)); + Set tensorTypes = init.get(src); + getOut(src).state.addAll(tensorTypes); } } diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java new file mode 100644 index 000000000..b034bfce8 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java @@ -0,0 +1,60 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import 
com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +/** + * Represents a call to the constant() function in TensorFlow. + * + * @see constant(). + * @author Raffi Khatchadourian + */ +public class Constant extends TensorGenerator { + + private static final int VALUE_NUMBER_FOR_VALUE_ARGUMENT = 2; + + private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 3; + + private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 4; + + public Constant(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { + // If the shape argument is not specified, then the shape is inferred from the shape of value. + // TODO: Handle keyword arguments. + return getShapes(builder, this.getValueNumberForValueArgument()); + } + + @Override + protected EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder) { + // If the dtype argument is not specified, then the type is inferred from the type of value. + // TODO: Handle keyword arguments. + return getDTypes(builder, this.getValueNumberForValueArgument()); + } + + @Override + protected int getValueNumberForDTypeArgument() { + return VALUE_NUMBER_FOR_DTYPE_ARGUMENT; + } + + protected int getValueNumberForValueArgument() { + return VALUE_NUMBER_FOR_VALUE_ARGUMENT; + } + + @Override + protected int getValueNumberForShapeArgument() { + // Shapes can also be specified as an explicit argument. Here, we examine the third explicit + // argument (recall that the first argument is implicit and corresponds to the called + // function's name). 
+ return VALUE_NUMBER_FOR_SHAPE_ARGUMENT; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java new file mode 100644 index 000000000..c6655d972 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java @@ -0,0 +1,54 @@ +package com.ibm.wala.cast.python.ml.client; + +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; + +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +/** + * A generator for tensors created by the `ones()` function in TensorFlow. + * + * @see TensorFlow ones() API. + * @author Raffi Khatchadourian + */ +public class Ones extends TensorGenerator { + + private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 2; + + private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 3; + + public Ones(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder) { + LOGGER.info( + "No dtype specified for source: " + source + ". Using default dtype of: " + FLOAT32 + " ."); + + // Use the default dtype of float32. + return EnumSet.of(FLOAT32); + } + + @Override + protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { + throw new UnsupportedOperationException( + "Shapes for ones() are mandatory and must be provided explicitly."); + } + + @Override + protected int getValueNumberForShapeArgument() { + return VALUE_NUMBER_FOR_SHAPE_ARGUMENT; // The shape is in the first explicit argument. 
+ } + + @Override + protected int getValueNumberForDTypeArgument() { + return VALUE_NUMBER_FOR_DTYPE_ARGUMENT; // The dtype is in the second explicit argument. + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 710b3b44a..1bd29592a 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -667,10 +667,9 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) Set sources = getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis()); - TensorType mnistData = TensorType.mnistInput(); - Map init = HashMapFactory.make(); + Map> init = HashMapFactory.make(); - for (PointsToSetVariable v : sources) init.put(v, mnistData); + for (PointsToSetVariable v : sources) init.put(v, getTensorTypes(v, builder)); Map placeholders = null; try { @@ -681,7 +680,7 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) logger.fine("Placeholders: " + placeholders); for (Map.Entry e : placeholders.entrySet()) - init.put(e.getKey(), e.getValue()); + init.put(e.getKey(), Set.of(e.getValue())); Map setCalls = HashMapFactory.make(); Map set_shapes = getShapeSourceCalls(set_shape, builder, 1); @@ -722,6 +721,24 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) return tt; } + /** + * Returns the set of possible {@link TensorType}s that the given {@link PointsToSetVariable} can + * take on. + * + * @param source The dataflow source to analyze. + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph and pointer + * analysis. 
+ * @return A set of {@link TensorType}s that the given {@link PointsToSetVariable} can take on. + * Empty set is returned if the possible tensor types cannot be determined. + */ + private Set getTensorTypes( + PointsToSetVariable source, PropagationCallGraphBuilder builder) { + logger.info("Getting tensor types for source: " + source + "."); + + TensorGenerator generator = TensorGeneratorFactory.getGenerator(source); + return generator.getTensorTypes(builder); + } + private Map handleShapeSourceOp( PropagationCallGraphBuilder builder, Graph dataflow, diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java new file mode 100644 index 000000000..3c6647827 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java @@ -0,0 +1,166 @@ +package com.ibm.wala.cast.python.ml.client; + +import static java.util.function.Function.identity; + +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.ConstantKey; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import com.ibm.wala.util.collections.HashSetFactory; +import com.ibm.wala.util.debug.UnimplementedError; +import com.ibm.wala.util.intset.OrdinalSet; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.StreamSupport; + +/** + * A representation of 
the TensorFlow range operation. + * + *

This class is used to generate a tensor that contains a sequence of numbers, similar to the + * range function in Python. + * + * @see TensorFlow range + * documentation. + * @author Raffi Khatchadourian + */ +public class Range extends TensorGenerator { + + public Range(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected Set>> getShapes(PropagationCallGraphBuilder builder) { + Set>> ret = HashSetFactory.make(); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + // The shape of a range tensor is always a 1D tensor with the length equal to the number of + // elements in the range. + // For example, `tf.range(5)` produces a tensor with shape (5,). + + double start = 0; // Default start value. + double limit = start; // Default limit value. + double delta = 1; // Default step value. + + // There are two versions of the `range` function: + // 1. `tf.range(limit)` - generates a range from 0 to limit + // 2. `tf.range(start, limit, delta)` - generates a range from start to limit with a step of + // delta. + + // First, decide which version of the `range` function is being called based on the number of + // numeric arguments.j + // TODO: Handle keyword arguments. + + int numOfNumericPositionalArgs = getNumberOfNumericPositionalArgs(pointerAnalysis); + + if (numOfNumericPositionalArgs == 1) { + // it must *just* be `limit`. + PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, 2); + OrdinalSet limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK); + + assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit."; + + for (InstanceKey limitIK : limitPointsToSet) + if (limitIK instanceof ConstantKey) { + limit = ((Number) ((ConstantKey) limitIK).getValue()).doubleValue(); + int shape = (int) Math.ceil((limit - start) / delta); + ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor. 
+ } else + throw new IllegalStateException( + "Expected a " + ConstantKey.class + " for limit, but got: " + limitIK + "."); + } else + // TODO: Handle more cases. + throw new UnimplementedError( + "Currently cannot handle more than one numeric positional argument for range()."); + + return ret; + } + + private int getNumberOfNumericPositionalArgs(PointerAnalysis pointerAnalysis) { + int ret = 0; + int explicitArgumentIndex = 2; // Start from the first explicit argument. + + while (true) { + PointerKey pk = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, explicitArgumentIndex); + OrdinalSet pointsToSet = pointerAnalysis.getPointsToSet(pk); + + if (pointsToSet.isEmpty()) break; // End of positional arguments. + + // Check if the pointsToSet contains numeric values. + boolean allNumeric = + StreamSupport.stream(pointsToSet.spliterator(), false) + .filter(ik -> ik instanceof ConstantKey) + .map(ik -> (ConstantKey) ik) + .map(ConstantKey::getValue) + .allMatch(v -> v instanceof Number); // Check if all values are numeric. + + if (!allNumeric) break; // There's some argument that is not numeric for this argument. + + ret++; // Increment the count of numeric positional arguments. + explicitArgumentIndex++; // Move to the next explicit argument. + } + + return ret; + } + + @Override + protected EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder) { + // The dtype of the resulting tensor is inferred from the inputs unless it is provided + // explicitly. + + // TODO: Handle keyword arguments. + int numberOfNumericPositionalArgs = + getNumberOfNumericPositionalArgs(builder.getPointerAnalysis()); + + EnumSet types = + IntStream.range(0, numberOfNumericPositionalArgs) + .map(i -> i + 2) // Positional arguments start at index 2. 
+ .mapToObj(val -> getDTypes(builder, val).stream()) + .flatMap(identity()) + .distinct() + .collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class))); + + // FIXME: We can't tell the difference here between varying dtypes in a single call and that of + // possible varying dtypes values from the points-to graph. Below, we are treating it as these + // values lie in a single call, but that may not be the case. + + if (types.contains(DType.FLOAT64)) return EnumSet.of(DType.FLOAT64); + else if (types.contains(DType.FLOAT32)) return EnumSet.of(DType.FLOAT32); + else if (types.contains(DType.INT64)) return EnumSet.of(DType.INT64); + else if (types.contains(DType.INT32)) return EnumSet.of(DType.INT32); + + throw new IllegalStateException( + "Expected at least one numeric dtype for range(), but got: " + types + "."); + } + + @Override + protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { + throw new UnsupportedOperationException( + "Shapes for range() are derived from mandatory numeric arguments and must be provided" + + " explicitly."); + } + + @Override + protected int getValueNumberForShapeArgument() { + throw new UnsupportedOperationException( + "Range does not have a shape argument. Its shape is derived from the numeric arguments."); + } + + @Override + protected int getValueNumberForDTypeArgument() { + // TODO: We need a value number for the dtype argument. Also, that value number can differ + // depending on the version of the `range` function being called. + + return -1; // Positional dtype argument for range() is not yet implemented. 
+ } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java new file mode 100644 index 000000000..5975ecb7d --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java @@ -0,0 +1,560 @@ +package com.ibm.wala.cast.python.ml.client; + +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.INT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.STRING; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.D_TYPE; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TENSORFLOW; +import static com.ibm.wala.cast.python.types.PythonTypes.list; +import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode; +import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom; +import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; + +import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory; +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes; +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; +import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.classLoader.IClass; +import com.ibm.wala.classLoader.IField; +import com.ibm.wala.classLoader.IMethod; +import com.ibm.wala.classLoader.NewSiteReference; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.ContextItem; +import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode; +import 
com.ibm.wala.ipa.callgraph.propagation.ConstantKey; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString; +import com.ibm.wala.types.Descriptor; +import com.ibm.wala.types.FieldReference; +import com.ibm.wala.types.MethodReference; +import com.ibm.wala.types.TypeReference; +import com.ibm.wala.util.collections.HashSetFactory; +import com.ibm.wala.util.intset.OrdinalSet; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.logging.Logger; + +public abstract class TensorGenerator { + + protected static final Logger LOGGER = Logger.getLogger(TensorGenerator.class.getName()); + + private static final MethodReference IMPORT = + MethodReference.findOrCreate( + TENSORFLOW, + findOrCreateAsciiAtom("import"), + Descriptor.findOrCreate(null, TENSORFLOW.getName())); + + protected PointsToSetVariable source; + + protected CGNode node; + + public TensorGenerator(PointsToSetVariable source, CGNode node) { + this.source = source; + this.node = node; + } + + public Set getTensorTypes(PropagationCallGraphBuilder builder) { + Set>> shapes = getShapes(builder); + EnumSet dTypes = getDTypes(builder); + + Set ret = HashSetFactory.make(); + + // Create a tensor type for each possible shape and dtype combination. + for (List> dimensionList : shapes) + for (DType dtype : dTypes) ret.add(new TensorType(dtype.name().toLowerCase(), dimensionList)); + + return ret; + } + + /** + * Returns the possible shapes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. 
+ * @param pointsToSet The points-to set of the shape argument. + * @return A set of possible shapes of the tensor returned by this generator. + */ + protected Set>> getShapesFromShapeArgument( + PropagationCallGraphBuilder builder, Iterable pointsToSet) { + Set>> ret = HashSetFactory.make(); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey instanceKey : pointsToSet) { + AllocationSiteInNode asin = getAllocationSiteInNode(instanceKey); + TypeReference reference = asin.getConcreteType().getReference(); + + if (reference.equals(list)) { // TODO: This can also be a tuple of tensors. + // We have a list of integers that represent the shape. + OrdinalSet objectCatalogPointsToSet = + pointerAnalysis.getPointsToSet( + ((AstPointerKeyFactory) builder.getPointerKeyFactory()) + .getPointerKeyForObjectCatalog(asin)); + + // We expect the object catalog to contain a list of integers. Each element in the array + // corresponds to the set of possible dimensions for that index. + @SuppressWarnings("unchecked") + Set>[] possibleDimensions = new Set[objectCatalogPointsToSet.size()]; + + for (InstanceKey catalogIK : objectCatalogPointsToSet) { + ConstantKey constantKey = (ConstantKey) catalogIK; + Object constantKeyValue = constantKey.getValue(); + + Integer fieldIndex = (Integer) constantKeyValue; + + FieldReference subscript = + FieldReference.findOrCreate( + PythonTypes.Root, findOrCreateAsciiAtom(fieldIndex.toString()), PythonTypes.Root); + + IField f = builder.getClassHierarchy().resolveField(subscript); + LOGGER.fine("Found field: " + f); + + // We can now get the pointer key for the instance field. + PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f); + LOGGER.fine("Found pointer key for instance field: " + pointerKeyForInstanceField + "."); + + // Get the points-to set for the instance field. 
+ OrdinalSet instanceFieldPointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForInstanceField); + LOGGER.fine("Points-to set for instance field: " + instanceFieldPointsToSet + "."); + + // If the instance field points to a constant, we can use it as the shape. + // TODO: Is it possible to also do it for (simple) expressions? + Set> tensorDimensions = HashSetFactory.make(); + + for (InstanceKey instanceFieldIK : instanceFieldPointsToSet) { + if (instanceFieldIK instanceof ConstantKey) { + // We have a constant key. + ConstantKey instanceFieldConstant = (ConstantKey) instanceFieldIK; + Object instanceFieldValue = instanceFieldConstant.getValue(); + + // We have a shape value. + Long shapeValue = (Long) instanceFieldValue; + LOGGER.fine( + "Found shape value: " + shapeValue + " for " + source.getPointerKey() + "."); + + Dimension dimension = new NumericDim(shapeValue.intValue()); + + LOGGER.fine("Adding dimension: " + dimension + "."); + tensorDimensions.add(dimension); + } else + throw new IllegalStateException( + "Expected a constant key for instance field: " + + pointerKeyForInstanceField + + ", but got: " + + instanceFieldIK + + "."); + } + + LOGGER.info( + "Found possible shape dimensions: " + + tensorDimensions + + " for field: " + + pointerKeyForInstanceField + + " for source: " + + source + + "."); + + // Add the shape dimensions. 
+ assert possibleDimensions[fieldIndex] == null + : "Duplicate field index: " + + fieldIndex + + " in object catalog: " + + objectCatalogPointsToSet + + "."; + + possibleDimensions[fieldIndex] = tensorDimensions; + LOGGER.fine( + "Added shape dimensions: " + + tensorDimensions + + " for field index: " + + fieldIndex + + "."); + } + + for (int i = 0; i < possibleDimensions.length; i++) + for (Dimension iDim : possibleDimensions[i]) { + @SuppressWarnings("unchecked") + Dimension[] dimensions = new Dimension[possibleDimensions.length]; + + dimensions[i] = iDim; + + for (int j = 0; j < possibleDimensions.length; j++) + if (i != j) + for (Dimension jDim : possibleDimensions[j]) dimensions[j] = jDim; + + ret.add(asList(dimensions)); + } + } else + throw new IllegalStateException( + "Expected a " + PythonTypes.list + " for the shape, but got: " + reference + "."); + } + + return ret; + } + + protected abstract Set>> getDefaultShapes(PropagationCallGraphBuilder builder); + + protected abstract int getValueNumberForShapeArgument(); + + /** + * Returns the possible shapes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @return a set of shapes, where each shape is represented as a list of dimensions + */ + protected Set>> getShapes(PropagationCallGraphBuilder builder) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + // Get the shape from the explicit argument. + // FIXME: Handle keyword arguments. + int shapeArgValueNum = this.getValueNumberForShapeArgument(); + + PointerKey pointerKey = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, shapeArgValueNum); + OrdinalSet pointsToSet = pointerAnalysis.getPointsToSet(pointerKey); + + // If the argument shape is not specified. + if (pointsToSet.isEmpty()) return getDefaultShapes(builder); + else + // The shape points-to set is non-empty, meaning that the shape was explicitly set. 
+ return getShapesFromShapeArgument(builder, pointsToSet); + } + + /** + * Returns the possible shapes of the tensor returned by this generator. The shape is inferred + * from the argument represented by the given value number. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param valueNumber The value number of the argument from which to infer the shape. + * @return A set of possible shapes of the tensor returned by this generator. + */ + protected Set>> getShapes( + PropagationCallGraphBuilder builder, int valueNumber) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + PointerKey valuePK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, valueNumber); + OrdinalSet valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK); + return getShapesOfValue(builder, valuePointsToSet); + } + + /** + * Returns the possible shapes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param valuePointsToSet The points-to set of the value from which the shape will be derived. + * @return A set of possible shapes of the tensor returned by this generator. + */ + private Set>> getShapesOfValue( + PropagationCallGraphBuilder builder, OrdinalSet valuePointsToSet) { + Set>> ret = HashSetFactory.make(); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey valueIK : valuePointsToSet) + if (valueIK instanceof ConstantKey) ret.add(emptyList()); // Scalar value. 
+ else if (valueIK instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = getAllocationSiteInNode(valueIK); + TypeReference reference = asin.getConcreteType().getReference(); + + if (reference.equals(list)) { + OrdinalSet objectCatalogPointsToSet = + pointerAnalysis.getPointsToSet( + ((AstPointerKeyFactory) builder.getPointerKeyFactory()) + .getPointerKeyForObjectCatalog(asin)); + + LOGGER.fine( + "The object catalog points-to set size is: " + objectCatalogPointsToSet.size() + "."); + + for (InstanceKey catalogIK : objectCatalogPointsToSet) { + ConstantKey constantKey = (ConstantKey) catalogIK; + Object constantKeyValue = constantKey.getValue(); + + Integer fieldIndex = (Integer) constantKeyValue; + + FieldReference subscript = + FieldReference.findOrCreate( + PythonTypes.Root, + findOrCreateAsciiAtom(fieldIndex.toString()), + PythonTypes.Root); + + IField f = builder.getClassHierarchy().resolveField(subscript); + LOGGER.fine("Found field: " + f); + + PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f); + LOGGER.fine( + "Found pointer key for instance field: " + pointerKeyForInstanceField + "."); + + OrdinalSet instanceFieldPointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForInstanceField); + LOGGER.fine("Points-to set for instance field: " + instanceFieldPointsToSet + "."); + + Set>> shapesOfField = + getShapesOfValue(builder, instanceFieldPointsToSet); + + for (List> shapeList : shapesOfField) { + List> shape = new ArrayList<>(); + + shape.add(new NumericDim(objectCatalogPointsToSet.size())); + shape.addAll(shapeList); + + ret.add(shape); + } + } + } else throw new IllegalStateException("Unknown type reference: " + reference + "."); + } else + throw new IllegalStateException( + "Expected a " + ConstantKey.class + " for value, but got: " + valueIK + "."); + + return ret; + } + + /** + * Returns the possible dtypes of the tensor returned by this generator. 
+ * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param pointsToSet The points-to set of the dtype argument, which is expected to be a set of + * type literals. + * @return A set of possible dtypes of the tensor returned by this generator. + */ + protected EnumSet getDTypesFromShapeArgument( + PropagationCallGraphBuilder builder, Iterable pointsToSet) { + EnumSet ret = EnumSet.noneOf(DType.class); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey instanceKey : pointsToSet) { + IClass concreteType = instanceKey.getConcreteType(); + TypeReference typeReference = concreteType.getReference(); + + if (typeReference.equals(TensorFlowTypes.D_TYPE)) { + // we have a dtype. + // let's see if it's float32. + Set importNodes = builder.getCallGraph().getNodes(IMPORT); + + // find the import node from this file. + Optional importNode = + importNodes.stream() + .filter( + in -> { + ContextItem contextItem = in.getContext().get(CALL_STRING); + CallString cs = (CallString) contextItem; + + // We expect the first method in the call string to be the import. + assert cs.getMethods().length == 1 + : "Expected a single method in the call string, but got: " + + cs.getMethods().length + + " for node: " + + in; + + IMethod method = cs.getMethods()[0]; + + CallString nodeCS = (CallString) node.getContext().get(CALL_STRING); + + // We expect the first method in the call string to be the import. 
+ assert nodeCS.getMethods().length == 1 + : "Expected a single method in the call string, but got: " + + nodeCS.getMethods().length + + " for node: " + + in; + + return method.equals(nodeCS.getMethods()[0]); + }) + .findFirst(); + + InstanceKey tensorFlowIK = + pointerAnalysis + .getHeapModel() + .getInstanceKeyForAllocation( + importNode.get(), NewSiteReference.make(0, TENSORFLOW)); + + FieldReference float32 = + FieldReference.findOrCreate( + PythonTypes.Root, findOrCreateAsciiAtom(FLOAT32.name().toLowerCase()), D_TYPE); + + IField float32Field = builder.getClassHierarchy().resolveField(float32); + + PointerKey float32PK = + pointerAnalysis + .getHeapModel() + .getPointerKeyForInstanceField(tensorFlowIK, float32Field); + + for (InstanceKey float32IK : pointerAnalysis.getPointsToSet(float32PK)) + if (float32IK.equals(instanceKey)) { + ret.add(FLOAT32); + LOGGER.info( + "Found dtype: " + + FLOAT32 + + " for source: " + + source + + " from dType: " + + instanceKey + + "."); + } else throw new IllegalStateException("Unknown dtype: " + instanceKey + "."); + } else + throw new IllegalStateException( + "Expected a " + + TensorFlowTypes.D_TYPE + + " for the dtype, but got: " + + typeReference + + "."); + } + + return ret; + } + + /** + * Returns a set of possible dtypes of the tensor returned by this generator when an explicit + * dtype isn't provided as an argument. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @return The set of possible dtypes of the tensor returned by this generator when an explicit + * dtype isn't provided as an argument. + */ + protected abstract EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder); + + /** + * Returns the value number for the dtype argument in the function call. + * + * @return The value number for the dtype argument in the function call or -1 if the dtype + * argument is not supported. 
+ */ + protected abstract int getValueNumberForDTypeArgument(); + + protected EnumSet getDTypes(PropagationCallGraphBuilder builder) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + int valNum = this.getValueNumberForDTypeArgument(); + OrdinalSet pointsToSet = null; + + if (valNum > 0) { + // The dtype is in an explicit argument. + // FIXME: Handle keyword arguments. + PointerKey pointerKey = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, valNum); + pointsToSet = pointerAnalysis.getPointsToSet(pointerKey); + } + + // If the argument dtype is not specified. + if (pointsToSet == null || pointsToSet.isEmpty()) return getDefaultDTypes(builder); + else + // The dtype points-to set is non-empty, meaning that the dtype was explicitly set. + return getDTypesFromShapeArgument(builder, pointsToSet); + } + + /** + * Returns the possible dtypes of the tensor returned by this generator. The dtype is inferred + * from the argument represented by the given value number. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param valueNumber The value number of the argument from which to infer the dtype. + * @return A set of possible dtypes of the tensor returned by this generator. + */ + protected EnumSet getDTypes(PropagationCallGraphBuilder builder, int valueNumber) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + PointerKey valuePK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, valueNumber); + OrdinalSet valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK); + return getDTypesOfValue(builder, valuePointsToSet); + } + + /** + * Returns the possible dtypes of the tensor returned by this generator. The dtype is inferred + * from the given points-to set. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param valuePointsToSet The points-to set of the value from which the dtype will be derived. 
+ * @return A set of possible dtypes of the tensor returned by this generator. + */ + private EnumSet getDTypesOfValue( + PropagationCallGraphBuilder builder, OrdinalSet valuePointsToSet) { + EnumSet ret = EnumSet.noneOf(DType.class); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey valueIK : valuePointsToSet) + if (valueIK instanceof ConstantKey) { // It's a scalar value. + ConstantKey constantKey = (ConstantKey) valueIK; + Object value = constantKey.getValue(); + if (value instanceof Float || value instanceof Double) { + ret.add(FLOAT32); + LOGGER.info( + "Inferred dtype: " + + FLOAT32 + + " for source: " + + source + + " from value: " + + value + + "."); + } else if (value instanceof Integer || value instanceof Long) { + ret.add(INT32); + LOGGER.info( + "Inferred dtype: " + + INT32 + + " for source: " + + source + + " from value: " + + value + + "."); + } else if (value instanceof String) { + ret.add(STRING); + LOGGER.info( + "Inferred dtype: " + + STRING + + " for source: " + + source + + " from value: " + + value + + "."); + } else throw new IllegalStateException("Unknown constant type: " + value.getClass() + "."); + } else if (valueIK instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = getAllocationSiteInNode(valueIK); + TypeReference reference = asin.getConcreteType().getReference(); + + if (reference.equals(list)) { + OrdinalSet objectCatalogPointsToSet = + pointerAnalysis.getPointsToSet( + ((AstPointerKeyFactory) builder.getPointerKeyFactory()) + .getPointerKeyForObjectCatalog(asin)); + + LOGGER.fine( + "The object catalog points-to set size is: " + objectCatalogPointsToSet.size() + "."); + + for (InstanceKey catalogIK : objectCatalogPointsToSet) { + ConstantKey constantKey = (ConstantKey) catalogIK; + Object constantKeyValue = constantKey.getValue(); + + Integer fieldIndex = (Integer) constantKeyValue; + + FieldReference subscript = + FieldReference.findOrCreate( + PythonTypes.Root, + 
findOrCreateAsciiAtom(fieldIndex.toString()), + PythonTypes.Root); + + IField f = builder.getClassHierarchy().resolveField(subscript); + LOGGER.fine("Found field: " + f); + + PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f); + LOGGER.fine( + "Found pointer key for instance field: " + pointerKeyForInstanceField + "."); + + OrdinalSet instanceFieldPointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForInstanceField); + LOGGER.fine("Points-to set for instance field: " + instanceFieldPointsToSet + "."); + + EnumSet dTypesOfField = getDTypesOfValue(builder, instanceFieldPointsToSet); + ret.addAll(dTypesOfField); + } + } else throw new IllegalStateException("Unknown type reference: " + reference + "."); + } else + // TODO: More cases. + throw new IllegalStateException( + "Expected a " + ConstantKey.class + " for value, but got: " + valueIK + "."); + return ret; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java new file mode 100644 index 000000000..10870338c --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java @@ -0,0 +1,56 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.cast.types.AstMethodReference; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.types.MethodReference; +import com.ibm.wala.types.TypeName; +import com.ibm.wala.types.TypeReference; +import java.util.logging.Logger; + +public class TensorGeneratorFactory { + + private static final Logger LOGGER = Logger.getLogger(TensorGeneratorFactory.class.getName()); + + 
/** https://www.tensorflow.org/api_docs/python/tf/ones. */ + private static final MethodReference ONES = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/ones")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/constant. */ + private static final MethodReference CONSTANT = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/constant")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/range. */ + private static final MethodReference RANGE = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/range")), + AstMethodReference.fnSelector); + + public static TensorGenerator getGenerator(PointsToSetVariable source) { + // Get the pointer key for the source. + PointerKey pointerKey = source.getPointerKey(); + + LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey; + CGNode node = localPointerKey.getNode(); + + TypeReference calledFunction = node.getMethod().getDeclaringClass().getReference(); + LOGGER.info("Getting tensor generator for call to: " + calledFunction.getName() + "."); + + if (calledFunction.equals(ONES.getDeclaringClass())) return new Ones(source, node); + else if (calledFunction.equals(CONSTANT.getDeclaringClass())) return new Constant(source, node); + else if (calledFunction.equals(RANGE.getDeclaringClass())) return new Range(source, node); + else + throw new IllegalArgumentException( + "Unknown call: " + calledFunction + " for source: " + source + "."); + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java index dbe791c2b..f0a392aa7 100644 --- 
a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java @@ -11,11 +11,34 @@ */ public class TensorFlowTypes extends PythonTypes { + /** + * Defined data types used in TensorFlow. + * + * @see TensorFlow + * dtypes. + * @author Raffi Khatchadourian + */ + public enum DType { + FLOAT32, + FLOAT64, + INT32, + INT64, + STRING; + } + public static final TypeReference TENSORFLOW = TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow")); public static final TypeReference DATASET = TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow/data/Dataset")); + /** + * Represents the TensorFlow data type. + * + * @see TensorFlow DType. + */ + public static final TypeReference D_TYPE = + TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow/dtypes/DType")); + private TensorFlowTypes() {} } diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add116.py b/com.ibm.wala.cast.python.test/data/tf2_test_add116.py new file mode 100644 index 000000000..55ebbe20e --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add116.py @@ -0,0 +1,17 @@ +import tensorflow as tf + + +def add(a, b): + assert a.shape == (1, 2), f"Expected shape (1, 2), got {a.shape}" + assert b.shape == (2, 2), f"Expected shape (2, 2), got {b.shape}" + + assert a.dtype == tf.float32, f"Expected dtype float32, got {a.dtype}" + assert b.dtype == tf.float32, f"Expected dtype float32, got {b.dtype}" + + return a + b + + +c = add(tf.ones([1, 2], tf.float32), tf.ones([2, 2], tf.float32)) + +assert c.shape == (2, 2), f"Expected shape (2, 2), got {c.shape}" +assert c.dtype == tf.float32, f"Expected dtype float32, got {c.dtype}" diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add117.py b/com.ibm.wala.cast.python.test/data/tf2_test_add117.py new file mode 100644 index 000000000..fe3255a50 --- /dev/null +++ 
b/com.ibm.wala.cast.python.test/data/tf2_test_add117.py @@ -0,0 +1,14 @@ +import tensorflow as tf +import random + + +def add(a, b): + return a + b + + +if random.random() < 0.5: + a = 1 +else: + a = 3 + +c = add(tf.ones([a, 2]), tf.ones([2, 2])) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add7.py b/com.ibm.wala.cast.python.test/data/tf2_test_add7.py index a21250c2e..dc4eb2017 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_add7.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add7.py @@ -2,7 +2,16 @@ def add(a, b): + assert a.shape == (1, 2), f"Expected shape (1, 2), got {a.shape}" + assert b.shape == (2, 2), f"Expected shape (2, 2), got {b.shape}" + + assert a.dtype == tf.float32, f"Expected dtype float32, got {a.dtype}" + assert b.dtype == tf.float32, f"Expected dtype float32, got {b.dtype}" + return a + b -c = add(tf.ones([1, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]] +c = add(tf.ones([1, 2]), tf.ones([2, 2])) + +assert c.shape == (2, 2), f"Expected shape (2, 2), got {c.shape}" +assert c.dtype == tf.float32, f"Expected dtype float32, got {c.dtype}" diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorator12.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorator12.py new file mode 100644 index 000000000..8d6329d8f --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorator12.py @@ -0,0 +1,14 @@ +import tensorflow as tf + + +@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),)) +@tf.function(reduce_retracing=True) +def returned(a): + return a + + +a = tf.constant([1, 1.0]) +b = returned(a) + +assert a.shape == (2,) +assert a.dtype == tf.float32 diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py index 3df833404..60531f3e7 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py @@ -7,4 +7,8 @@ def returned(a): a = tf.range(5) + 
+assert a.shape == (5,) +assert a.dtype == tf.int32 + b = returned(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py index e65180b79..f7dff82fc 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py @@ -9,3 +9,6 @@ def returned(a): a = tf.constant([1.0, 1.0]) b = returned(a) + +assert a.shape == (2,) +assert a.dtype == tf.float32 diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function.py b/com.ibm.wala.cast.python.test/data/tf2_test_function.py index 05006b94c..c4e000163 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_function.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function.py @@ -8,7 +8,11 @@ def func2(t): @tf.function def func(): a = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + assert a.shape == (2, 2) + b = tf.constant([[1.0, 1.0], [0.0, 1.0]]) + assert b.shape == (2, 2) + c = tf.matmul(a, b) tensor = tf.experimental.numpy.ndarray(c.op, 0, tf.float32) func2(tensor) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function10.py b/com.ibm.wala.cast.python.test/data/tf2_test_function10.py new file mode 100644 index 000000000..f4b4c9d45 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function10.py @@ -0,0 +1,16 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant( + [ + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], + ] +) +assert a.shape == (2, 3, 4) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function11.py b/com.ibm.wala.cast.python.test/data/tf2_test_function11.py new file mode 100644 index 000000000..2e150de13 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function11.py @@ -0,0 +1,17 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant( + [ + [[1, 2, 3], [5, 6, 7], [9, 10, 11]], + [[13, 14, 15], [17, 18, 19], 
[21, 22, 23]], + ] +) +assert a.shape == (2, 3, 3) +assert a.dtype == tf.int32 + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function12.py b/com.ibm.wala.cast.python.test/data/tf2_test_function12.py new file mode 100644 index 000000000..cf82dfd35 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function12.py @@ -0,0 +1,21 @@ +import tensorflow as tf +from random import random + + +def func(t): + pass + + +n = random() + +a = None + +if n > 0.5: + a = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + assert a.shape == (3, 2) +else: + a = tf.constant([[1.0], [3.0]]) + assert a.shape == (2, 1) + +assert a.shape == (3, 2) or a.shape == (2, 1) +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function13.py b/com.ibm.wala.cast.python.test/data/tf2_test_function13.py new file mode 100644 index 000000000..9149486d0 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function13.py @@ -0,0 +1,19 @@ +import tensorflow as tf +from random import random + + +def func(t): + pass + + +n = random() + +if n > 0.5: + a = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + assert a.shape == (3, 2) +else: + a = tf.constant([[1.0], [3.0]]) + assert a.shape == (2, 1) + +assert a.shape == (3, 2) or a.shape == (2, 1) +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function5.py b/com.ibm.wala.cast.python.test/data/tf2_test_function5.py new file mode 100644 index 000000000..375130982 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function5.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([[1.0, 2.0], [3.0, 4.0]]) +assert a.shape == (2, 2) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function6.py b/com.ibm.wala.cast.python.test/data/tf2_test_function6.py new file mode 100644 index 000000000..61c99f02a --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function6.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def 
func(t): + pass + + +a = tf.constant([[1.0], [3.0]]) +assert a.shape == (2, 1) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function7.py b/com.ibm.wala.cast.python.test/data/tf2_test_function7.py new file mode 100644 index 000000000..3acf6acff --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function7.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([1.0, 3.0]) +assert a.shape == (2,) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function8.py b/com.ibm.wala.cast.python.test/data/tf2_test_function8.py new file mode 100644 index 000000000..ae30e4842 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function8.py @@ -0,0 +1,19 @@ +import tensorflow as tf +from random import random + + +def func(t): + pass + + +n = random() + +if n > 0.5: + l = [[1.0], [3.0]] +else: + l = [1.0, 3.0] + +a = tf.constant(l) +assert a.shape == (2, 1) or a.shape == (2,) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function9.py b/com.ibm.wala.cast.python.test/data/tf2_test_function9.py new file mode 100644 index 000000000..43c43b9e4 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function9.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([[1.0, 3.0]]) +assert a.shape == (1, 2) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_static_method13.py b/com.ibm.wala.cast.python.test/data/tf2_test_static_method13.py new file mode 100644 index 000000000..a60222983 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_static_method13.py @@ -0,0 +1,16 @@ +import tensorflow as tf + + +class MyClass: + + @staticmethod + def the_static_method(x): + assert isinstance(x, tf.Tensor) + + +a = tf.constant(1, tf.float32, (5,)) + +assert a.shape == (5,) +assert a.dtype == tf.float32 + +MyClass.the_static_method(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_static_method14.py 
b/com.ibm.wala.cast.python.test/data/tf2_test_static_method14.py new file mode 100644 index 000000000..2c2457989 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_static_method14.py @@ -0,0 +1,16 @@ +import tensorflow as tf + + +class MyClass: + + @staticmethod + def the_static_method(x): + assert isinstance(x, tf.Tensor) + + +a = tf.constant(1, tf.float32, ([1, 2])) + +assert a.shape == (1, 2) +assert a.dtype == tf.float32 + +MyClass.the_static_method(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py index 86167f7db..723ece779 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py @@ -7,5 +7,11 @@ def add(a, b): list = [tf.ones([1, 2]), tf.ones([2, 2])] +assert list[0].shape == (1, 2) +assert list[1].shape == (2, 2) + +assert list[0].dtype == tf.float32 +assert list[1].dtype == tf.float32 + for element in list: c = add(element, element) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..2d2500d7c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[tool.black] +extend-exclude = ''' + /( + IDE + | jython3 + )/ +'''