diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index f5bb3124c..6e3454852 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1266,6 +1266,12 @@ public void testCallbacks2() test("tf2_test_callbacks2.py", "replica_fn", 1, 1, 2); } + @Test + public void testCallbacks3() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_callbacks3.py", "dataset_fn", 0, 0); + } + @Test public void testGanTutorial() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py b/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py new file mode 100644 index 000000000..570806c8e --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py @@ -0,0 +1,21 @@ +# From https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategydistribute_datasets_from_function. + +import tensorflow as tf + +global_batch_size = 16 +strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + + +def dataset_fn(input_context): + batch_size = input_context.get_per_replica_batch_size(global_batch_size) + dataset = tf.data.Dataset.from_tensors(([1.0], [1.0])).repeat(64).batch(16) + dataset = dataset.shard( + input_context.num_input_pipelines, input_context.input_pipeline_id + ) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(2) # This prefetches 2 batches per device. + return dataset + + +dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) +print(dist_dataset)