diff --git a/requirements.txt b/requirements.txt index b606eddc5..d89b67553 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ certifi==2023.7.22 charset-normalizer==3.2.0 click==8.1.7 cmake==3.25.0 -datasets==2.14.4 +datasets==2.15.0 deepspeed==0.10.1 dill==0.3.7 docker-pycreds==0.4.0 diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py index 26aa7a876..dbd1211c9 100644 --- a/trlx/models/modeling_base.py +++ b/trlx/models/modeling_base.py @@ -37,7 +37,7 @@ PeftModel, get_peft_config, get_peft_model, - prepare_model_for_int8_training, + prepare_model_for_kbit_training, ) @@ -217,7 +217,7 @@ def from_pretrained( # noqa: max-complexity ) if is_loaded_in_8bit: - base_model = prepare_model_for_int8_training( + base_model = prepare_model_for_kbit_training( base_model, **peft_int8_kwargs, ) @@ -255,7 +255,7 @@ def from_pretrained( # noqa: max-complexity if peft_config is not None: if is_loaded_in_8bit: - base_model = prepare_model_for_int8_training( + base_model = prepare_model_for_kbit_training( base_model, **peft_int8_kwargs, )