From 9c008a5732b05d61794da983061f11b09f38a323 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Wed, 6 Dec 2023 16:19:18 -0500 Subject: [PATCH 1/2] Add preliminary example. --- .../python/ml/test/TestTensorflowModel.java | 1 + .../data/tf2_test_callbacks3.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index e075ac588..9ea682805 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -214,6 +214,7 @@ public void testTf2() testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 3); testTf2("tf2_test_callbacks.py", "replica_fn", 1, 3, 2); testTf2("tf2_test_callbacks2.py", "replica_fn", 1, 4, 2); + testTf2("tf2_test_callbacks3.py", "dataset_fn", 0, 0); } private void testTf2( diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py b/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py new file mode 100644 index 000000000..ba6e27d84 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py @@ -0,0 +1,20 @@ +# From https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategydistribute_datasets_from_function. + +import tensorflow as tf + +global_batch_size = 16 +strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + + +def dataset_fn(input_context): + batch_size = input_context.get_per_replica_batch_size(global_batch_size) + dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16) + dataset = dataset.shard( + input_context.num_input_pipelines, input_context.input_pipeline_id) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(2) # This prefetches 2 batches per device. + return dataset + + +dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) +print(dist_dataset) From 5d778cf9694590b3386a287fa25c010caec89a4f Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 5 Mar 2024 10:33:53 -0500 Subject: [PATCH 2/2] Apply spotless. --- .../data/tf2_test_callbacks3.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py b/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py index ba6e27d84..570806c8e 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py @@ -7,13 +7,14 @@ def dataset_fn(input_context): - batch_size = input_context.get_per_replica_batch_size(global_batch_size) - dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16) - dataset = dataset.shard( - input_context.num_input_pipelines, input_context.input_pipeline_id) - dataset = dataset.batch(batch_size) - dataset = dataset.prefetch(2) # This prefetches 2 batches per device. - return dataset + batch_size = input_context.get_per_replica_batch_size(global_batch_size) + dataset = tf.data.Dataset.from_tensors(([1.0], [1.0])).repeat(64).batch(16) + dataset = dataset.shard( + input_context.num_input_pipelines, input_context.input_pipeline_id + ) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(2) # This prefetches 2 batches per device. + return dataset dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)