diff --git a/nmt/utils/misc_utils.py b/nmt/utils/misc_utils.py index a680a5cf2..080a7f215 100644 --- a/nmt/utils/misc_utils.py +++ b/nmt/utils/misc_utils.py @@ -25,12 +25,13 @@ import time import numpy as np +from distutils.version import LooseVersion, StrictVersion import tensorflow as tf def check_tensorflow_version(): min_tf_version = "1.4.0-dev20171024" - if tf.__version__ < min_tf_version: + if LooseVersion(tf.__version__) < LooseVersion(min_tf_version): raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version)