This repository was archived by the owner on Oct 9, 2024. It is now read-only.

Loading saved model fails #104

@VladimirShitov


Hi! I am following the new tutorial in the scvi-tools documentation. Everything works fine, but I want to save the model and reuse it later to avoid retraining it each time. I do that by running:
model.save("models/mrvi_no_nuissanse", overwrite=True)

Then I try to load the saved model by running either:
model.load("models/mrvi_no_nuissanse")
or:
model.load("models/mrvi_no_nuissanse", adata)

However, in both cases, the following error is raised:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
07_mrvi_analysis.ipynb Cell 19 line 1
----> 1 model.load("models/mrvi_no_nuissanse", adata)

File /lib/python3.10/site-packages/scvi/model/base/_base_model.py:693, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url)
    680 load_adata = adata is None
    681 _, _, device = parse_device_args(
    682     accelerator=accelerator,
    683     devices=device,
    684     return_device="torch",
    685     validate_single_device=True,
    686 )
    688 (
    689     attr_dict,
    690     var_names,
    691     model_state_dict,
    692     new_adata,
--> 693 ) = _load_saved_files(
    694     dir_path,
    695     load_adata,
    696     map_location=device,
    697     prefix=prefix,
    698     backup_url=backup_url,
    699 )
    700 adata = new_adata if new_adata is not None else adata
    702 _validate_var_names(adata, var_names)

File /lib/python3.10/site-packages/scvi/model/base/_save_load.py:71, in _load_saved_files(dir_path, load_adata, prefix, map_location, backup_url)
     69 try:
     70     _download(backup_url, dir_path, model_file_name)
---> 71     model = torch.load(model_path, map_location=map_location)
     72 except FileNotFoundError as exc:
     73     raise ValueError(
     74         f"Failed to load model file at {model_path}. "
     75         "If attempting to load a saved model from <v0.15.0, please use the util function "
     76         "`convert_legacy_save` to convert to an updated format."
     77     ) from exc

File /lib/python3.10/site-packages/torch/serialization.py:1025, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1023             except RuntimeError as e:
   1024                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1025         return _load(opened_zipfile,
   1026                      map_location,
   1027                      pickle_module,
   1028                      overall_storage=overall_storage,
   1029                      **pickle_load_args)
   1030 if mmap:
   1031     f_name = "" if not isinstance(f, str) else f"{f}, "

File /lib/python3.10/site-packages/torch/serialization.py:1442, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1439         return super().find_class(mod_name, name)
   1441 # Load the data (which may in turn use `persistent_load` to load tensors)
-> 1442 data_file = io.BytesIO(zip_file.get_record(pickle_file))
   1444 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1445 unpickler.persistent_load = persistent_load

RuntimeError: PytorchStreamReader failed locating file data.pkl: file not found
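
The last frame shows `torch.load` treating the saved file as a zip archive and failing to find a `data.pkl` record in it. A minimal diagnostic sketch, using only the standard library, can list what the saved file actually contains. The file name `model.pt` inside the save directory is an assumption here, and the helpers `list_checkpoint_records` and `has_data_pkl` are hypothetical names, not part of scvi-tools or PyTorch. This only inspects the file and does not by itself explain or fix the failure.

```python
import zipfile
from pathlib import Path


def list_checkpoint_records(model_path):
    """Return the record names inside a zip-based checkpoint file.

    Returns None if the file is missing or is not a zip archive.
    """
    path = Path(model_path)
    if not zipfile.is_zipfile(path):
        return None
    with zipfile.ZipFile(path) as archive:
        return archive.namelist()


def has_data_pkl(records):
    """Check whether any record is named data.pkl, ignoring its parent folder."""
    return any(name.split("/")[-1] == "data.pkl" for name in records)


if __name__ == "__main__":
    # Assumption: the model file is called "model.pt" inside the save directory.
    model_file = Path("models/mrvi_no_nuissanse") / "model.pt"
    records = list_checkpoint_records(model_file)
    if records is None:
        print(f"{model_file} is missing or is not a zip archive")
    else:
        print("records:", records)
        print("contains data.pkl:", has_data_pkl(records))
```

If the archive has no `data.pkl` record, that would suggest the file was written in a different layout from the one `torch.load` expects, which would be useful detail to add to this report.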

I guess it can be fixed by saving the data along with the model:
model.save("models/mrvi_no_nuissanse", overwrite=True, save_anndata=True)

However, that would require storing an unnecessary copy of a large dataset. Could you please provide instructions on how to save and load the model? They would also fit nicely in the tutorial.
