Skip to content

Commit 772b35b

Browse files
committed
fix: remove trailing os sep in local pretrained model path
1 parent 2a61590 commit 772b35b

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/transformers/dynamic_module_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def get_cached_module_file(
375375
local_files_only = True
376376

377377
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
378-
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
378+
pretrained_model_name_or_path = str(pretrained_model_name_or_path).rstrip(os.sep)
379379
is_local = os.path.isdir(pretrained_model_name_or_path)
380380
if is_local:
381381
submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path))

tests/utils/test_dynamic_module_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import pytest
1818

19+
import warnings
20+
21+
from transformers import AutoConfig
1922
from transformers.dynamic_module_utils import get_imports
2023

2124

@@ -127,3 +130,24 @@ def test_import_parsing(tmp_path, case):
127130

128131
parsed_imports = get_imports(tmp_file_path)
129132
assert parsed_imports == ["os"]
133+
134+
def test_local_path_with_and_without_trailing_slash(tmp_path):
135+
model_dir = tmp_path / "my_model"
136+
model_dir.mkdir()
137+
config_path = model_dir / "config.json"
138+
config_path.write_text("{}")
139+
path_no_slash = str(model_dir)
140+
path_with_slash = str(model_dir) + os.sep
141+
142+
with warnings.catch_warnings(record=True) as w1:
143+
warnings.simplefilter("always")
144+
cfg1 = AutoConfig.from_pretrained(path_no_slash)
145+
146+
with warnings.catch_warnings(record=True) as w2:
147+
warnings.simplefilter("always")
148+
cfg2 = AutoConfig.from_pretrained(path_with_slash)
149+
150+
assert isinstance(cfg1, type(cfg2))
151+
assert len(w1) == 0
152+
assert len(w2) == 0
153+

0 commit comments

Comments
 (0)