diff --git a/requirements.txt b/requirements.txt index b606eddc5..17af4ebaa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,82 +1,83 @@ ---extra-index-url https://download.pytorch.org/whl/cu118 -accelerate==0.22.0 -aiohttp==3.8.5 +--extra-index-url https://download.pytorch.org/whl/cu121 +accelerate==0.24.1 +aiohttp==3.8.6 aiosignal==1.3.1 +annotated-types==0.6.0 appdirs==1.4.4 async-timeout==4.0.3 attrs==23.1.0 cattrs==23.1.2 certifi==2023.7.22 -charset-normalizer==3.2.0 +charset-normalizer==3.3.2 click==8.1.7 -cmake==3.25.0 -datasets==2.14.4 -deepspeed==0.10.1 +cuda-python==12.3.0 +datasets==2.14.6 +deepspeed==0.12.2 dill==0.3.7 docker-pycreds==0.4.0 -einops==0.6.1 +einops==0.7.0 exceptiongroup==1.1.3 filelock==3.9.0 frozenlist==1.4.0 -fsspec==2023.6.0 -gitdb==4.0.10 -GitPython==3.1.32 -grpcio==1.57.0 +fsspec==2023.4.0 +gitdb==4.0.11 +GitPython==3.1.40 hjson==3.1.0 -huggingface-hub==0.16.4 +huggingface-hub==0.17.3 idna==3.4 Jinja2==3.1.2 -jsonschema==4.19.0 +jsonschema==4.19.2 jsonschema-specifications==2023.7.1 -lit==15.0.7 markdown-it-py==3.0.0 MarkupSafe==2.1.2 mdurl==0.1.2 -mpmath==1.2.1 -msgpack==1.0.5 +mpmath==1.3.0 +msgpack==1.0.7 multidict==6.0.4 multiprocess==0.70.15 networkx==3.0 -ninja==1.11.1 -numpy==1.25.2 -packaging==23.1 -pandas==2.0.3 +ninja==1.11.1.1 +numpy==1.26.1 +packaging==23.2 +pandas==2.1.2 pathtools==0.1.2 -peft==0.5.0 -protobuf==4.24.2 -psutil==5.9.5 +peft==0.6.0 +protobuf==4.25.0 +psutil==5.9.6 py-cpuinfo==9.0.0 -pyarrow==13.0.0 -pydantic==1.10.12 +pyarrow==14.0.0 +pydantic==2.4.2 +pydantic_core==2.10.1 Pygments==2.16.1 +pynvml==11.5.0 python-dateutil==2.8.2 -python-rapidjson==1.10 -pytz==2023.3 +python-rapidjson==1.13 +pytz==2023.3.post1 PyYAML==6.0.1 -ray==2.6.3 +ray==2.8.0 referencing==0.30.2 -regex==2023.8.8 +regex==2023.10.3 requests==2.31.0 -rich==13.5.2 -rpds-py==0.9.2 -safetensors==0.3.3 -sentry-sdk==1.29.2 -setproctitle==1.3.2 +rich==13.6.0 +rpds-py==0.12.0 +safetensors==0.4.0 +sentry-sdk==1.34.0 +setproctitle==1.3.3 
six==1.16.0 -smmap==5.0.0 -sympy==1.11.1 +smmap==5.0.1 +sympy==1.12 tabulate==0.9.0 -tokenizers==0.13.3 -torch==2.0.1+cu118 +tokenizers==0.14.1 +torch==2.1.0+cu121 torchtyping==0.1.4 tqdm==4.66.1 -transformers==4.32.0 -triton==2.0.0 -tritonclient==2.36.0 -typeguard==4.1.3 -typing_extensions==4.7.1 +transformers==4.35.0 +triton==2.1.0 +tritonclient==2.39.0 +typeguard==4.1.5 +typing_extensions==4.8.0 tzdata==2023.3 -urllib3==2.0.4 -wandb==0.15.8 -xxhash==3.3.0 +urllib3==2.0.7 +wandb==0.15.12 +xxhash==3.4.1 yarl==1.9.2 diff --git a/tests/test_peft.py b/tests/test_peft.py index ffe6f8bcb..37dd164af 100644 --- a/tests/test_peft.py +++ b/tests/test_peft.py @@ -10,7 +10,7 @@ import torch import transformers from peft import get_peft_config, get_peft_model -from peft.utils.config import PeftType, TaskType +from peft.utils import PeftType, TaskType from transformers import AutoConfig, AutoModelForCausalLM from trlx.data.configs import TokenizerConfig diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py index 26aa7a876..330e28495 100644 --- a/trlx/models/modeling_base.py +++ b/trlx/models/modeling_base.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn import transformers -from huggingface_hub import hf_hub_download +import huggingface_hub import trlx.utils.logging as logging from trlx.utils import is_peft_available @@ -155,8 +155,10 @@ def from_pretrained( # noqa: max-complexity call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific instance of the wrapped model. + NOTE: You must pass in arguments specific to the wrapped model as keyword arguments. 
""" + if kwargs is not None: peft_from_pretrained_kwargs = kwargs.pop("peft_from_pretrained_kwargs", {}) peft_int8_kwargs = kwargs.pop("peft_int8_kwargs", {}) @@ -273,42 +275,41 @@ def from_pretrained( # noqa: max-complexity model = cls(base_model, **wrapped_model_kwargs) if isinstance(pretrained_model_name_or_path, str): - filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") - sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") - is_sharded = False - - if not os.path.exists(filename): + if not os.path.exists(pretrained_model_name_or_path): try: - filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin", revision=revision) - # Sharded - except Exception: - if os.path.exists(sharded_index_filename): - index_file_name = sharded_index_filename - else: - index_file_name = hf_hub_download( - pretrained_model_name_or_path, - "pytorch_model.bin.index.json", - revision=revision, - ) - with open(index_file_name, "r") as f: - index = json.load(f) - - # Load all weights from the shards - files_to_download = set(index["weight_map"].values()) - is_sharded = True - - if is_sharded: - # Merge each shard into a state dict - # TODO: Optimize this to avoid wasting RAM - state_dict = {} - for shard_file in files_to_download: - filename = os.path.join(pretrained_model_name_or_path, shard_file) - # Download if shard file doesn't exist locally - if not os.path.exists(filename): - filename = hf_hub_download(pretrained_model_name_or_path, shard_file, revision=revision) - state_dict.update(torch.load(filename, map_location="cpu")) + pretrained_model_name_or_path = huggingface_hub.snapshot_download(pretrained_model_name_or_path, revision=revision) + except huggingface_hub.utils._errors.RepositoryNotFoundError: + raise ValueError("Invalid `pretrained_model_name_or_path`. 
It should be a local path or a repository name.")
+
+            sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
+            sharded_safetensors_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
+
+            if os.path.exists(sharded_index_filename):
+                with open(sharded_index_filename, "r") as f:
+                    index = json.load(f)
+                # Shard names in the index are relative to the checkpoint directory;
+                # resolve them to full paths once so both load branches below agree.
+                shards = set(os.path.join(pretrained_model_name_or_path, v) for v in index["weight_map"].values())
+            elif os.path.exists(sharded_safetensors_index_filename):
+                with open(sharded_safetensors_index_filename, "r") as f:
+                    index = json.load(f)
+                shards = set(os.path.join(pretrained_model_name_or_path, v) for v in index["weight_map"].values())
             else:
-                state_dict = torch.load(filename, map_location="cpu")
+                shard_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+                shard_safetensors_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+                if os.path.exists(shard_filename):
+                    shards = [shard_filename]
+                elif os.path.exists(shard_safetensors_filename):
+                    shards = [shard_safetensors_filename]
+                else:
+                    raise ValueError(
+                        f"No checkpoint weights found under `{pretrained_model_name_or_path}`."
+                    )
+
+            state_dict = {}
+            for shard in shards:
+                if shard.endswith(".safetensors"):
+                    import safetensors
+
+                    with safetensors.safe_open(shard, framework="pt") as f:
+                        for k in f.keys():
+                            state_dict[k] = f.get_tensor(k)
+                else:
+                    state_dict.update(torch.load(shard, map_location="cpu"))
         else:
             state_dict = pretrained_model_name_or_path.state_dict()