diff --git a/jDAS/__init__.py b/jDAS/__init__.py index abd8156..1ce0eea 100644 --- a/jDAS/__init__.py +++ b/jDAS/__init__.py @@ -1,6 +1,7 @@ import sys import os -import keras +os.environ.setdefault("TF_USE_LEGACY_KERAS", "1") +import tensorflow as tf import numpy as np import gc @@ -32,7 +33,7 @@ def load_model(self, model_path=None): if model_path is None: model_path = os.path.join(cwd, "..", "models", "pretrained_model.h5") - self.model = keras.models.load_model(model_path) + self.model = tf.keras.models.load_model(model_path, compile=False) return self.model diff --git a/jDAS/filters.py b/jDAS/filters.py index 2bbee4c..e86f5a4 100644 --- a/jDAS/filters.py +++ b/jDAS/filters.py @@ -1,4 +1,9 @@ -from scipy.signal import tukey, butter, filtfilt, sosfiltfilt +from scipy.signal import butter, filtfilt, sosfiltfilt +try: + # SciPy >= 1.12 keeps window functions under scipy.signal.windows + from scipy.signal.windows import tukey +except ImportError: # pragma: no cover - fallback for older SciPy + from scipy.signal import tukey import scipy.fft # Scipy's FFT package is faster than Numpy's @@ -108,4 +113,4 @@ def taper_filter(arr, fmin, fmax, samp, order=2, mode="sos"): sos = _butter_bandpass(fmin, fmax, samp, order, mode) arr_filt = sosfiltfilt(sos, arr_wind, axis=-1) - return arr_filt \ No newline at end of file + return arr_filt