diff --git a/llm_web_kit/model/html_classify/model.py b/llm_web_kit/model/html_classify/model.py
index 76ac4724..cdb1d637 100644
--- a/llm_web_kit/model/html_classify/model.py
+++ b/llm_web_kit/model/html_classify/model.py
@@ -1,5 +1,7 @@
import torch
+from llm_web_kit.model.resource_utils import import_transformer
+
class Markuplm():
def __init__(self, path, device):
@@ -16,13 +18,13 @@ def __init__(self, path, device):
self.tokenizer = self.load_tokenizer()
def load_tokenizer(self):
- from transformers import MarkupLMProcessor
+ transformers = import_transformer()
- return MarkupLMProcessor.from_pretrained(self.model_path)
+ return transformers.MarkupLMProcessor.from_pretrained(self.model_path)
def load_model(self):
- from transformers import MarkupLMForSequenceClassification
- model = MarkupLMForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
+ transformers = import_transformer()
+ model = transformers.MarkupLMForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
# load checkpoint
model.load_state_dict(torch.load(self.checkpoint_path, map_location=self.device))
model.to(self.device)
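The hunk above swaps the inline `from transformers import ...` statements for `import_transformer()`, which returns the `transformers` module itself, so the same classes are reached as attributes (`transformers.MarkupLMProcessor`, `transformers.MarkupLMForSequenceClassification`). A generic sketch of this deferred-import pattern follows; `lazy_import` and the paths are illustrative only and not part of the repository:

```python
# Illustrative sketch of the deferred-import pattern used above: the heavy dependency is
# imported only when first needed, after any environment variables it reads are in place.
import importlib
import os
from typing import Dict, Optional


def lazy_import(module_name: str, env: Optional[Dict[str, str]] = None):
    """Set environment variables, then import and return the named module."""
    for key, value in (env or {}).items():
        os.environ[key] = value
    return importlib.import_module(module_name)


# Usage mirroring load_tokenizer() above (path is a placeholder):
# transformers = lazy_import('transformers', env={'HF_HOME': '/tmp/hf_cache'})
# processor = transformers.MarkupLMProcessor.from_pretrained('/path/to/markuplm')
```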
diff --git a/llm_web_kit/model/policical.py b/llm_web_kit/model/policical.py
index dff0e50a..7894c147 100644
--- a/llm_web_kit/model/policical.py
+++ b/llm_web_kit/model/policical.py
@@ -9,6 +9,7 @@
from llm_web_kit.libs.logger import mylogger as logger
from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file,
get_unzip_dir,
+ import_transformer,
singleton_resource_manager,
unzip_local_file)
@@ -18,8 +19,7 @@ class PoliticalDetector:
def __init__(self, model_path: str = None):
# import AutoTokenizer here to avoid isort error
# must set the HF_HOME to the CACHE_DIR at this point
- os.environ['HF_HOME'] = CACHE_DIR
- from transformers import AutoTokenizer
+ transformer = import_transformer()
if not model_path:
model_path = self.auto_download()
@@ -27,7 +27,7 @@ def __init__(self, model_path: str = None):
tokenizer_path = os.path.join(model_path, 'internlm2-chat-20b')
self.model = fasttext.load_model(model_bin_path)
- self.tokenizer = AutoTokenizer.from_pretrained(
+ self.tokenizer = transformer.AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=False, trust_remote_code=True
)
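The two removed lines made the ordering explicit: `HF_HOME` has to be exported before `transformers` is imported, because the cache location is resolved when the library is first loaded (the original inline comment guards exactly this). `import_transformer()` bakes that ordering into one call so call sites cannot get it wrong. A minimal sketch of the ordering constraint, with a placeholder cache path and assuming `transformers` is installed:

```python
# Sketch: the environment variable only takes effect if it is set before the first import.
import os

os.environ['HF_HOME'] = '/data/hf_cache'   # placeholder cache directory; must come first
import transformers                         # noqa: E402  cache location is picked up here

# Setting HF_HOME after this point does not relocate the cache for the current process,
# which is why the removed code set it immediately before `from transformers import ...`.
```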
diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py
index d60c6b7a..c2b6d443 100644
--- a/llm_web_kit/model/porn_detector.py
+++ b/llm_web_kit/model/porn_detector.py
@@ -3,19 +3,23 @@
from typing import Dict, List, Union
import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llm_web_kit.config.cfg_reader import load_config
from llm_web_kit.libs.logger import mylogger as logger
from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file,
- get_unzip_dir, unzip_local_file)
+ get_unzip_dir,
+ import_transformer,
+ unzip_local_file)
+
+# from transformers import AutoModelForSequenceClassification, AutoTokenizer
class BertModel:
def __init__(self, model_path: str = None) -> None:
if not model_path:
model_path = self.auto_download()
- self.model = AutoModelForSequenceClassification.from_pretrained(
+ transformers_module = import_transformer()
+ self.model = transformers_module.AutoModelForSequenceClassification.from_pretrained(
os.path.join(model_path, 'porn_classifier/classifier_hf')
)
with open(
@@ -37,7 +41,7 @@ def __init__(self, model_path: str = None) -> None:
if hasattr(self.model, 'to_bettertransformer'):
self.model = self.model.to_bettertransformer()
- self.tokenizer = AutoTokenizer.from_pretrained(
+ self.tokenizer = transformers_module.AutoTokenizer.from_pretrained(
os.path.join(model_path, 'porn_classifier/classifier_hf')
)
self.tokenizer_config = {
diff --git a/llm_web_kit/model/resource_utils/__init__.py b/llm_web_kit/model/resource_utils/__init__.py
index 79ea734a..966c5555 100644
--- a/llm_web_kit/model/resource_utils/__init__.py
+++ b/llm_web_kit/model/resource_utils/__init__.py
@@ -1,6 +1,14 @@
from .download_assets import download_auto_file
from .singleton_resource_manager import singleton_resource_manager
from .unzip_ext import get_unzip_dir, unzip_local_file
-from .utils import CACHE_DIR, CACHE_TMP_DIR
+from .utils import CACHE_DIR, CACHE_TMP_DIR, import_transformer
-__all__ = ['download_auto_file', 'unzip_local_file', 'get_unzip_dir', 'CACHE_DIR', 'CACHE_TMP_DIR', 'singleton_resource_manager']
+__all__ = [
+ 'download_auto_file',
+ 'unzip_local_file',
+ 'get_unzip_dir',
+ 'CACHE_DIR',
+ 'CACHE_TMP_DIR',
+ 'singleton_resource_manager',
+ 'import_transformer',
+]
diff --git a/llm_web_kit/model/resource_utils/download_assets.py b/llm_web_kit/model/resource_utils/download_assets.py
index 9ea85f95..a0fe0ca7 100644
--- a/llm_web_kit/model/resource_utils/download_assets.py
+++ b/llm_web_kit/model/resource_utils/download_assets.py
@@ -201,6 +201,7 @@ def download_auto_file_core(
progress = tqdm(total=total_size, unit='iB', unit_scale=True)
# Use a temporary directory to ensure atomicity
+ os.makedirs(CACHE_TMP_DIR, exist_ok=True)
with tempfile.TemporaryDirectory(dir=CACHE_TMP_DIR) as temp_dir:
download_path = os.path.join(temp_dir, 'download_file')
try:
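`tempfile.TemporaryDirectory(dir=...)` does not create the directory passed as `dir`; since this diff also removes the import-time `os.makedirs` calls from `resource_utils/utils.py` (see below), the download path now has to create `CACHE_TMP_DIR` itself. A small self-contained illustration, using a throwaway base path:

```python
# tempfile.TemporaryDirectory(dir=...) requires the parent directory to already exist.
import os
import tempfile

base = os.path.join(tempfile.gettempdir(), 'llm_web_kit_demo_tmp')  # throwaway example path

try:
    tempfile.TemporaryDirectory(dir=os.path.join(base, 'does_not_exist'))
except FileNotFoundError as exc:
    print('fails while the parent is missing:', exc)

os.makedirs(base, exist_ok=True)            # same guard as the added line above
with tempfile.TemporaryDirectory(dir=base) as temp_dir:
    print('works once the parent exists:', temp_dir)
```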
diff --git a/llm_web_kit/model/resource_utils/unzip_ext.py b/llm_web_kit/model/resource_utils/unzip_ext.py
index 66622595..6a4a8575 100644
--- a/llm_web_kit/model/resource_utils/unzip_ext.py
+++ b/llm_web_kit/model/resource_utils/unzip_ext.py
@@ -83,6 +83,9 @@ def unzip_local_file_core(
if os.path.exists(target_dir):
raise ModelResourceException(f'Target directory {target_dir} already exists')
+ # make sure the parent directory exists
+ os.makedirs(os.path.dirname(target_dir), exist_ok=True)
+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
if password:
zip_ref.setpassword(password.encode())
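The added `os.makedirs` guarantees that the parent of `target_dir` exists before extraction, while `target_dir` itself is still required not to exist (the check a few lines above). A short illustration of what `os.path.dirname` resolves to here, with an illustrative path:

```python
# os.makedirs(os.path.dirname(target_dir), exist_ok=True) creates every missing ancestor
# of target_dir, but not target_dir itself.
import os
import tempfile

target_dir = os.path.join(tempfile.gettempdir(), 'llm_web_kit_demo', 'models', 'markuplm')
parent = os.path.dirname(target_dir)        # .../llm_web_kit_demo/models

os.makedirs(parent, exist_ok=True)          # safe to call even if it already exists
print(os.path.isdir(parent), os.path.exists(target_dir))   # True False (on a fresh run)
```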
diff --git a/llm_web_kit/model/resource_utils/utils.py b/llm_web_kit/model/resource_utils/utils.py
index 4ea78dda..42dacbf5 100644
--- a/llm_web_kit/model/resource_utils/utils.py
+++ b/llm_web_kit/model/resource_utils/utils.py
@@ -32,13 +32,6 @@ def decide_cache_dir():
CACHE_DIR, CACHE_TMP_DIR = decide_cache_dir()
-if not os.path.exists(CACHE_DIR):
- os.makedirs(CACHE_DIR, exist_ok=True)
-
-if not os.path.exists(CACHE_TMP_DIR):
- os.makedirs(CACHE_TMP_DIR, exist_ok=True)
-
-
def try_remove(path: str):
"""Attempt to remove a file by os.remove or to remove a directory by
shutil.rmtree and ignore exceptions."""
@@ -49,3 +42,11 @@ def try_remove(path: str):
os.remove(path)
except Exception:
pass
+
+
+def import_transformer():
+ os.environ['HF_HOME'] = CACHE_DIR
+ os.makedirs(CACHE_DIR, exist_ok=True)
+ import transformers
+
+ return transformers
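With the import-time `os.makedirs` calls gone, `import_transformer()` is now the single place that exports `HF_HOME`, creates `CACHE_DIR`, and imports `transformers`. A possible unit-test sketch for it (assumes pytest and an environment where `transformers` is installed; not part of this diff):

```python
# Sketch of a pytest-style check for import_transformer (illustrative, not from the diff).
import os
from unittest import mock

from llm_web_kit.model.resource_utils import utils


def test_import_transformer_sets_hf_home(tmp_path):
    fake_cache = str(tmp_path / 'hf_cache')
    # CACHE_DIR is read from the module globals at call time, so patching it is enough.
    with mock.patch.object(utils, 'CACHE_DIR', fake_cache):
        transformers_module = utils.import_transformer()

    assert os.environ['HF_HOME'] == fake_cache
    assert os.path.isdir(fake_cache)
    assert hasattr(transformers_module, 'AutoTokenizer')
```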
diff --git a/tests/llm_web_kit/model/test_porn_detector.py b/tests/llm_web_kit/model/test_porn_detector.py
index dbb8da77..146d65fd 100644
--- a/tests/llm_web_kit/model/test_porn_detector.py
+++ b/tests/llm_web_kit/model/test_porn_detector.py
@@ -1,20 +1,28 @@
+import logging
import os
import sys
import unittest
from unittest import TestCase
from unittest.mock import MagicMock, mock_open, patch
+from transformers import logging as transformers_logging
+
+from llm_web_kit.model.porn_detector import BertModel
+
current_file_path = os.path.abspath(__file__)
parent_dir_path = os.path.join(current_file_path, *[os.pardir] * 4)
normalized_path = os.path.normpath(parent_dir_path)
sys.path.append(normalized_path)
-from llm_web_kit.model.porn_detector import BertModel # noqa: E402
+
+transformers_logging.set_verbosity_error()
+
+logging.disable(logging.CRITICAL)
class TestBertModel(TestCase):
- @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained')
- @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained')
+ @patch('transformers.AutoModelForSequenceClassification.from_pretrained')
+ @patch('transformers.AutoTokenizer.from_pretrained')
@patch('llm_web_kit.model.porn_detector.os.path.join')
@patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}')
@patch('llm_web_kit.model.porn_detector.BertModel.auto_download')
@@ -52,8 +60,8 @@ def test_init(self, mock_auto_download, mock_open, mock_os_path_join, mock_from_
}
)
- @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained')
- @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained')
+ @patch('transformers.AutoModelForSequenceClassification.from_pretrained')
+ @patch('transformers.AutoTokenizer.from_pretrained')
@patch('llm_web_kit.model.porn_detector.os.path.join')
@patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}')
@patch('llm_web_kit.model.porn_detector.BertModel.auto_download')
@@ -107,8 +115,8 @@ def test_pre_process(self, mock_torch, mock_auto_download, mock_open, mock_os_pa
self.assertEqual(result['inputs']['input_ids'], expected_input_ids)
self.assertEqual(result['inputs']['attention_mask'], expected_attn_mask)
- @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained')
- @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained')
+ @patch('transformers.AutoModelForSequenceClassification.from_pretrained')
+ @patch('transformers.AutoTokenizer.from_pretrained')
@patch('llm_web_kit.model.porn_detector.os.path.join')
@patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}')
@patch('llm_web_kit.model.porn_detector.BertModel.auto_download')
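The patch targets move from `llm_web_kit.model.porn_detector.AutoTokenizer` to `transformers.AutoTokenizer` because, after this diff, `porn_detector` no longer binds those names in its own namespace; the lookup happens on the `transformers` module at call time, so that is where the mock has to be installed. A minimal illustration of the new patching style (assumes `transformers` is importable; no weights are loaded because `from_pretrained` is mocked):

```python
# "Patch where the object is looked up": with the lazy import, that place is the
# transformers module itself, not the consuming llm_web_kit module.
from unittest.mock import MagicMock, patch

from llm_web_kit.model.resource_utils import import_transformer


@patch('transformers.AutoTokenizer.from_pretrained', return_value=MagicMock(name='tokenizer'))
def load_with_mock(mock_from_pretrained):
    transformers_module = import_transformer()
    tokenizer = transformers_module.AutoTokenizer.from_pretrained('/path/does/not/matter')
    assert mock_from_pretrained.called
    return tokenizer


load_with_mock()
```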