diff --git a/datasets/prepare_dataset.py b/datasets/prepare_dataset.py index bde90e9..579dfd1 100644 --- a/datasets/prepare_dataset.py +++ b/datasets/prepare_dataset.py @@ -4,7 +4,10 @@ import cv2 import numpy as np import json -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from scipy.misc import imresize diff --git a/datasets/prepare_div2k_dataset.py b/datasets/prepare_div2k_dataset.py index 58dd404..044efe9 100644 --- a/datasets/prepare_div2k_dataset.py +++ b/datasets/prepare_div2k_dataset.py @@ -2,7 +2,10 @@ import argparse from tqdm import tqdm import cv2 -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf def bytes_feature(value): diff --git a/evaluate.py b/evaluate.py index 6444834..fb8b2a2 100644 --- a/evaluate.py +++ b/evaluate.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf import numpy as np import argparse import os diff --git a/generate_header_and_model.py b/generate_header_and_model.py index 0df6a08..792b664 100644 --- a/generate_header_and_model.py +++ b/generate_header_and_model.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf import numpy as np import argparse import os diff --git a/models/dataset.py b/models/dataset.py index f2da7d0..3fdf130 100644 --- a/models/dataset.py +++ b/models/dataset.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from collections import OrderedDict diff --git a/models/image_warp.py b/models/image_warp.py index 6a2c4e4..803d0e3 100644 --- a/models/image_warp.py +++ b/models/image_warp.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf def image_warp(images, flow, name='image_warp'): @@ -39,4 +42,4 @@ def gather(y_coords, x_coords): interpolated = tf.reshape(interpolated, shape) - return interpolated \ No newline at end of file + return interpolated diff --git a/models/model.py b/models/model.py index a2cdb30..394627f 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,8 @@ from abc import ABC, abstractmethod -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf import json from .dataset import Dataset diff --git a/models/model_espcn.py b/models/model_espcn.py index 08d50ec..bfc52bc 100644 --- a/models/model_espcn.py +++ b/models/model_espcn.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from .model import Model diff --git a/models/model_srcnn.py b/models/model_srcnn.py index 6b2f0ed..bf47c3e 100644 --- a/models/model_srcnn.py +++ b/models/model_srcnn.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from .model import Model diff --git a/models/model_vespcn.py b/models/model_vespcn.py index f1717f5..fa24eeb 100644 --- a/models/model_vespcn.py +++ b/models/model_vespcn.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from .model import Model from .image_warp import image_warp diff --git a/models/model_vsrnet.py b/models/model_vsrnet.py index 5659424..d1325a5 100644 --- a/models/model_vsrnet.py +++ b/models/model_vsrnet.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from .model import Model diff --git a/train.py b/train.py index 5f98c01..a92631b 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,7 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf import argparse from tqdm import tqdm import os