Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 48 additions & 47 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,82 +1,83 @@
--extra-index-url https://download.pytorch.org/whl/cu118
accelerate==0.22.0
aiohttp==3.8.5
--extra-index-url https://download.pytorch.org/whl/cu121
accelerate==0.24.1
aiohttp==3.8.6
aiosignal==1.3.1
annotated-types==0.6.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
cattrs==23.1.2
certifi==2023.7.22
charset-normalizer==3.2.0
charset-normalizer==3.3.2
click==8.1.7
cmake==3.25.0
datasets==2.14.4
deepspeed==0.10.1
cuda-python==12.3.0
datasets==2.14.6
deepspeed==0.12.2
dill==0.3.7
docker-pycreds==0.4.0
einops==0.6.1
einops==0.7.0
exceptiongroup==1.1.3
filelock==3.9.0
frozenlist==1.4.0
fsspec==2023.6.0
gitdb==4.0.10
GitPython==3.1.32
grpcio==1.57.0
fsspec==2023.4.0
gitdb==4.0.11
GitPython==3.1.40
hjson==3.1.0
huggingface-hub==0.16.4
huggingface-hub==0.17.3
idna==3.4
Jinja2==3.1.2
jsonschema==4.19.0
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
lit==15.0.7
markdown-it-py==3.0.0
MarkupSafe==2.1.2
mdurl==0.1.2
mpmath==1.2.1
msgpack==1.0.5
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
multiprocess==0.70.15
networkx==3.0
ninja==1.11.1
numpy==1.25.2
packaging==23.1
pandas==2.0.3
ninja==1.11.1.1
numpy==1.26.1
packaging==23.2
pandas==2.1.2
pathtools==0.1.2
peft==0.5.0
protobuf==4.24.2
psutil==5.9.5
peft==0.6.0
protobuf==4.25.0
psutil==5.9.6
py-cpuinfo==9.0.0
pyarrow==13.0.0
pydantic==1.10.12
pyarrow==14.0.0
pydantic==2.4.2
pydantic_core==2.10.1
Pygments==2.16.1
pynvml==11.5.0
python-dateutil==2.8.2
python-rapidjson==1.10
pytz==2023.3
python-rapidjson==1.13
pytz==2023.3.post1
PyYAML==6.0.1
ray==2.6.3
ray==2.8.0
referencing==0.30.2
regex==2023.8.8
regex==2023.10.3
requests==2.31.0
rich==13.5.2
rpds-py==0.9.2
safetensors==0.3.3
sentry-sdk==1.29.2
setproctitle==1.3.2
rich==13.6.0
rpds-py==0.12.0
safetensors==0.4.0
sentry-sdk==1.34.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.0
sympy==1.11.1
smmap==5.0.1
sympy==1.12
tabulate==0.9.0
tokenizers==0.13.3
torch==2.0.1+cu118
tokenizers==0.14.1
torch==2.1.0+cu121
torchtyping==0.1.4
tqdm==4.66.1
transformers==4.32.0
triton==2.0.0
tritonclient==2.36.0
typeguard==4.1.3
typing_extensions==4.7.1
transformers==4.35.0
triton==2.1.0
tritonclient==2.39.0
typeguard==4.1.5
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.4
wandb==0.15.8
xxhash==3.3.0
urllib3==2.0.7
wandb==0.15.12
xxhash==3.4.1
yarl==1.9.2
2 changes: 1 addition & 1 deletion tests/test_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
import transformers
from peft import get_peft_config, get_peft_model
from peft.utils.config import PeftType, TaskType
from peft.utils import PeftType, TaskType
from transformers import AutoConfig, AutoModelForCausalLM

from trlx.data.configs import TokenizerConfig
Expand Down
71 changes: 36 additions & 35 deletions trlx/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download
import huggingface_hub

import trlx.utils.logging as logging
from trlx.utils import is_peft_available
Expand Down Expand Up @@ -155,8 +155,10 @@ def from_pretrained( # noqa: max-complexity
call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific
instance of the wrapped model.


NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.
"""

if kwargs is not None:
peft_from_pretrained_kwargs = kwargs.pop("peft_from_pretrained_kwargs", {})
peft_int8_kwargs = kwargs.pop("peft_int8_kwargs", {})
Expand Down Expand Up @@ -273,42 +275,41 @@ def from_pretrained( # noqa: max-complexity
model = cls(base_model, **wrapped_model_kwargs)

if isinstance(pretrained_model_name_or_path, str):
filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
is_sharded = False

if not os.path.exists(filename):
if not os.path.exists(pretrained_model_name_or_path):
try:
filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin", revision=revision)
# Sharded
except Exception:
if os.path.exists(sharded_index_filename):
index_file_name = sharded_index_filename
else:
index_file_name = hf_hub_download(
pretrained_model_name_or_path,
"pytorch_model.bin.index.json",
revision=revision,
)
with open(index_file_name, "r") as f:
index = json.load(f)

# Load all weights from the shards
files_to_download = set(index["weight_map"].values())
is_sharded = True

if is_sharded:
# Merge each shard into a state dict
# TODO: Optimize this to avoid wasting RAM
state_dict = {}
for shard_file in files_to_download:
filename = os.path.join(pretrained_model_name_or_path, shard_file)
# Download if shard file doesn't exist locally
if not os.path.exists(filename):
filename = hf_hub_download(pretrained_model_name_or_path, shard_file, revision=revision)
state_dict.update(torch.load(filename, map_location="cpu"))
pretrained_model_name_or_path = huggingface_hub.snapshot_download(pretrained_model_name_or_path, revision=revision)
except huggingface_hub.utils._errors.RepositoryNotFoundError:
raise ValueError("Invalid `pretrained_model_name_or_path`. It should be a local path or a repository name.")

sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
sharded_safetensors_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")

if os.path.exists(sharded_index_filename):
with open(sharded_index_filename, "r") as f:
index = json.load(f)
shards = set(index["weight_map"].values())
elif os.path.exists(sharded_safetensors_index_filename):
with open(sharded_safetensors_index_filename, "r") as f:
index = json.load(f)
shards = set(index["weight_map"].values())
else:
state_dict = torch.load(filename, map_location="cpu")
shard_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
shard_safetensors_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
if os.path.exists(shard_filename):
shards = [shard_filename]
elif os.path.exists(shard_safetensors_filename):
shards = [shard_safetensors_filename]

state_dict = {}
for shard in shards:
if shard.endswith(".safetensors"):
import safetensors
with safetensors.safe_open(os.path.join(pretrained_model_name_or_path, shard), framework="pt") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
else:
state_dict.update(torch.load(shard, map_location="cpu"))

else:
state_dict = pretrained_model_name_or_path.state_dict()

Expand Down