Loading saved model fails #104
Description
Hi! I am following the new tutorial in the scvi-tools documentation. Everything works fine, but I want to save the trained model and reuse it later so that I don't have to retrain it each time. I save it by running:
```python
model.save("models/mrvi_no_nuissanse", overwrite=True)
```
Then I try to load the saved model by running either:
```python
model.load("models/mrvi_no_nuissanse")
```
or:
```python
model.load("models/mrvi_no_nuissanse", adata)
```
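As the traceback below shows (`BaseModelClass.load(cls, ...)`), `load` is a classmethod: it constructs and returns a new model rather than restoring weights into the existing `model` object, so its return value needs to be assigned (for example `model = model.load("models/mrvi_no_nuissanse", adata=adata)`). That alone does not explain the error, which is raised inside `torch.load`, but it matters once loading works. The sketch below illustrates the same save/load pattern with a hypothetical `TinyModel` class (not part of scvi-tools) that stores only its parameters and takes the data again at load time, similar in spirit to saving without the AnnData.

```python
import json
import os
import tempfile


class TinyModel:
    """Hypothetical stand-in for a model class; NOT the scvi-tools API."""

    def __init__(self, data, weights=None):
        self.data = data
        # Hypothetical "trained" parameters.
        self.weights = weights if weights is not None else [0.0] * len(data)

    def save(self, dir_path, overwrite=False):
        # Only the parameters are written; the (potentially large) data is not.
        # With overwrite=False an existing directory raises FileExistsError.
        os.makedirs(dir_path, exist_ok=overwrite)
        with open(os.path.join(dir_path, "model.json"), "w") as f:
            json.dump({"weights": self.weights}, f)

    @classmethod
    def load(cls, dir_path, data):
        # A classmethod builds and returns a NEW instance;
        # the data has to be supplied again by the caller.
        with open(os.path.join(dir_path, "model.json")) as f:
            state = json.load(f)
        return cls(data, weights=state["weights"])


data = [1.0, 2.0, 3.0]
model = TinyModel(data, weights=[0.5, 0.25, 0.125])
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_model")
    model.save(path, overwrite=True)
    # Assign the returned object; calling load without assignment discards it.
    restored = TinyModel.load(path, data)
print(restored.weights)  # [0.5, 0.25, 0.125]
```

Calling `model.load(...)` on an instance also works, because Python passes the class to a classmethod either way, but the result is still a separate object.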
However, in both cases, the following error is raised:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
07_mrvi_analysis.ipynb Cell 19 line 1
----> 1 model.load("models/mrvi_no_nuissanse", adata)

File /lib/python3.10/site-packages/scvi/model/base/_base_model.py:693, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url)
    680 load_adata = adata is None
    681 _, _, device = parse_device_args(
    682 accelerator=accelerator,
    683 devices=device,
    684 return_device="torch",
    685 validate_single_device=True,
    686 )
    688 (
    689 attr_dict,
    690 var_names,
    691 model_state_dict,
    692 new_adata,
--> 693 ) = _load_saved_files(
    694 dir_path,
    695 load_adata,
    696 map_location=device,
    697 prefix=prefix,
    698 backup_url=backup_url,
    699 )
    700 adata = new_adata if new_adata is not None else adata
    702 _validate_var_names(adata, var_names)

File /lib/python3.10/site-packages/scvi/model/base/_save_load.py:71, in _load_saved_files(dir_path, load_adata, prefix, map_location, backup_url)
     69 try:
     70 _download(backup_url, dir_path, model_file_name)
---> 71 model = torch.load(model_path, map_location=map_location)
     72 except FileNotFoundError as exc:
     73 raise ValueError(
     74 f"Failed to load model file at {model_path}. "
     75 "If attempting to load a saved model from <v0.15.0, please use the util function "
     76 "`convert_legacy_save` to convert to an updated format."
     77 ) from exc

File /lib/python3.10/site-packages/torch/serialization.py:1025, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1023 except RuntimeError as e:
   1024 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1025 return _load(opened_zipfile,
   1026 map_location,
   1027 pickle_module,
   1028 overall_storage=overall_storage,
   1029 **pickle_load_args)
   1030 if mmap:
   1031 f_name = "" if not isinstance(f, str) else f"{f}, "

File /lib/python3.10/site-packages/torch/serialization.py:1442, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1439 return super().find_class(mod_name, name)
   1441 # Load the data (which may in turn use `persistent_load` to load tensors)
-> 1442 data_file = io.BytesIO(zip_file.get_record(pickle_file))
   1444 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1445 unpickler.persistent_load = persistent_load

RuntimeError: PytorchStreamReader failed locating file data.pkl: file not found
```

I guess it can be fixed by saving the data along with the model:
```python
model.save("models/mrvi_no_nuissanse", overwrite=True, save_anndata=True)
```
However, that would mean storing an unnecessary copy of a large dataset. Could you please provide instructions on how to save and load the model? They would also fit nicely in the tutorial.
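Before falling back to `save_anndata=True`, it may help to check whether the checkpoint written by `save` is intact, because the error comes from `torch.load` failing to find the pickled object inside the file. Below is a minimal standard-library sketch. It assumes PyTorch's default zip-based format, where the pickled object is stored as an entry such as `archive/data.pkl`, and it assumes the model file inside the save directory is named `model.pt`. `inspect_torch_checkpoint` is a hypothetical helper, and the demo archives are synthetic placeholders, not real checkpoints.

```python
import os
import tempfile
import zipfile


def inspect_torch_checkpoint(path):
    """Hypothetical helper: report whether `path` looks like a zip-format torch checkpoint.

    Assumes PyTorch's default zip serialization, where the pickled object is an
    entry whose name ends in 'data.pkl' (e.g. 'archive/data.pkl'). The suffix
    check is an approximation of what torch's reader actually looks up.
    """
    if not os.path.exists(path):
        return "missing"
    if not zipfile.is_zipfile(path):
        return "not a zip archive"
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
    if any(name.endswith("data.pkl") for name in names):
        return "ok"
    return "zip archive without data.pkl"


# Demo on synthetic archives (placeholders, not real checkpoints).
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.pt")
    with zipfile.ZipFile(good, "w") as zf:
        zf.writestr("archive/data.pkl", b"placeholder")
    bad = os.path.join(tmp, "bad.pt")
    with zipfile.ZipFile(bad, "w") as zf:
        zf.writestr("archive/version", b"3")
    good_status = inspect_torch_checkpoint(good)  # "ok"
    bad_status = inspect_torch_checkpoint(bad)  # "zip archive without data.pkl"
    missing_status = inspect_torch_checkpoint(os.path.join(tmp, "nope.pt"))  # "missing"
```

Running `inspect_torch_checkpoint("models/mrvi_no_nuissanse/model.pt")` against the real save directory would then show whether the problem lies in the file written by `save` (missing, not a zip archive, or no `data.pkl` entry) rather than in how `load` is called.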