diff --git a/mnist-cnn/training/mnist-cnn.py b/mnist-cnn/training/mnist-cnn.py index 7ff71d5..21b29bf 100644 --- a/mnist-cnn/training/mnist-cnn.py +++ b/mnist-cnn/training/mnist-cnn.py @@ -156,7 +156,7 @@ def run_test(weights_file, test_file): output_names = ['prob'] output_names.extend(monitor_names) - param_dict = np.load(weights_file, encoding='latin1').item() + param_dict = np.load(weights_file, encoding='latin1', allow_pickle=True).item() predictor = OfflinePredictor(PredictConfig( model=Model(), session_init=DictRestore(param_dict),