From d70af66fc5563e53f4188fd0353ba16afad4ae42 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Tue, 11 Mar 2025 14:44:00 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=AF=BC=E5=85=A5=E6=96=B9=E5=BC=8F=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?import=5Ftransformer=E5=87=BD=E6=95=B0=E6=9B=BF=E4=BB=A3?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5=E5=AF=BC=E5=85=A5transformers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_web_kit/model/policical.py | 6 ++--- llm_web_kit/model/porn_detector.py | 12 ++++++---- llm_web_kit/model/resource_utils/__init__.py | 12 ++++++++-- llm_web_kit/model/resource_utils/utils.py | 7 ++++++ tests/llm_web_kit/model/test_porn_detector.py | 22 +++++++++++++------ 5 files changed, 43 insertions(+), 16 deletions(-) diff --git a/llm_web_kit/model/policical.py b/llm_web_kit/model/policical.py index dff0e50a..7894c147 100644 --- a/llm_web_kit/model/policical.py +++ b/llm_web_kit/model/policical.py @@ -9,6 +9,7 @@ from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, get_unzip_dir, + import_transformer, singleton_resource_manager, unzip_local_file) @@ -18,8 +19,7 @@ class PoliticalDetector: def __init__(self, model_path: str = None): # import AutoTokenizer here to avoid isort error # must set the HF_HOME to the CACHE_DIR at this point - os.environ['HF_HOME'] = CACHE_DIR - from transformers import AutoTokenizer + transformer = import_transformer() if not model_path: model_path = self.auto_download() @@ -27,7 +27,7 @@ def __init__(self, model_path: str = None): tokenizer_path = os.path.join(model_path, 'internlm2-chat-20b') self.model = fasttext.load_model(model_bin_path) - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer = transformer.AutoTokenizer.from_pretrained( tokenizer_path, use_fast=False, trust_remote_code=True ) diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py index d60c6b7a..c2b6d443 100644 --- a/llm_web_kit/model/porn_detector.py +++ b/llm_web_kit/model/porn_detector.py @@ -3,19 +3,23 @@ from typing import Dict, List, Union import torch -from transformers import AutoModelForSequenceClassification, AutoTokenizer from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, - get_unzip_dir, unzip_local_file) + get_unzip_dir, + import_transformer, + unzip_local_file) + +# from transformers import AutoModelForSequenceClassification, AutoTokenizer class BertModel: def __init__(self, model_path: str = None) -> None: if not model_path: model_path = self.auto_download() - self.model = AutoModelForSequenceClassification.from_pretrained( + transformers_module = import_transformer() + self.model = transformers_module.AutoModelForSequenceClassification.from_pretrained( os.path.join(model_path, 'porn_classifier/classifier_hf') ) with open( @@ -37,7 +41,7 @@ def __init__(self, model_path: str = None) -> None: if hasattr(self.model, 'to_bettertransformer'): self.model = self.model.to_bettertransformer() - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer = transformers_module.AutoTokenizer.from_pretrained( os.path.join(model_path, 'porn_classifier/classifier_hf') ) self.tokenizer_config = { diff --git a/llm_web_kit/model/resource_utils/__init__.py b/llm_web_kit/model/resource_utils/__init__.py index 79ea734a..966c5555 100644 --- a/llm_web_kit/model/resource_utils/__init__.py +++ b/llm_web_kit/model/resource_utils/__init__.py @@ -1,6 +1,14 @@ from .download_assets import download_auto_file from .singleton_resource_manager import singleton_resource_manager from .unzip_ext import get_unzip_dir, unzip_local_file -from .utils import CACHE_DIR, CACHE_TMP_DIR +from .utils import CACHE_DIR, CACHE_TMP_DIR, import_transformer -__all__ = ['download_auto_file', 'unzip_local_file', 'get_unzip_dir', 'CACHE_DIR', 'CACHE_TMP_DIR', 'singleton_resource_manager'] +__all__ = [ + 'download_auto_file', + 'unzip_local_file', + 'get_unzip_dir', + 'CACHE_DIR', + 'CACHE_TMP_DIR', + 'singleton_resource_manager', + 'import_transformer', +] diff --git a/llm_web_kit/model/resource_utils/utils.py b/llm_web_kit/model/resource_utils/utils.py index 4ea78dda..63595eee 100644 --- a/llm_web_kit/model/resource_utils/utils.py +++ b/llm_web_kit/model/resource_utils/utils.py @@ -49,3 +49,10 @@ def try_remove(path: str): os.remove(path) except Exception: pass + + +def import_transformer(): + os.environ['HF_HOME'] = CACHE_DIR + import transformers + + return transformers diff --git a/tests/llm_web_kit/model/test_porn_detector.py b/tests/llm_web_kit/model/test_porn_detector.py index dbb8da77..146d65fd 100644 --- a/tests/llm_web_kit/model/test_porn_detector.py +++ b/tests/llm_web_kit/model/test_porn_detector.py @@ -1,20 +1,28 @@ +import logging import os import sys import unittest from unittest import TestCase from unittest.mock import MagicMock, mock_open, patch +from transformers import logging as transformers_logging + +from llm_web_kit.model.porn_detector import BertModel # noqa: E402 + current_file_path = os.path.abspath(__file__) parent_dir_path = os.path.join(current_file_path, *[os.pardir] * 4) normalized_path = os.path.normpath(parent_dir_path) sys.path.append(normalized_path) -from llm_web_kit.model.porn_detector import BertModel # noqa: E402 + +transformers_logging.set_verbosity_error() + +logging.disable(logging.CRITICAL) class TestBertModel(TestCase): - @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained') - @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained') + @patch('transformers.AutoModelForSequenceClassification.from_pretrained') + @patch('transformers.AutoTokenizer.from_pretrained') @patch('llm_web_kit.model.porn_detector.os.path.join') @patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}') @patch('llm_web_kit.model.porn_detector.BertModel.auto_download') @@ -52,8 +60,8 @@ def test_init(self, mock_auto_download, mock_open, mock_os_path_join, mock_from_ } ) - @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained') - @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained') + @patch('transformers.AutoModelForSequenceClassification.from_pretrained') + @patch('transformers.AutoTokenizer.from_pretrained') @patch('llm_web_kit.model.porn_detector.os.path.join') @patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}') @patch('llm_web_kit.model.porn_detector.BertModel.auto_download') @@ -107,8 +115,8 @@ def test_pre_process(self, mock_torch, mock_auto_download, mock_open, mock_os_pa self.assertEqual(result['inputs']['input_ids'], expected_input_ids) self.assertEqual(result['inputs']['attention_mask'], expected_attn_mask) - @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained') - @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained') + @patch('transformers.AutoModelForSequenceClassification.from_pretrained') + @patch('transformers.AutoTokenizer.from_pretrained') @patch('llm_web_kit.model.porn_detector.os.path.join') @patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}') @patch('llm_web_kit.model.porn_detector.BertModel.auto_download') From 177263469dd2bf5def78d3d4d5cdc208fd471b92 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Tue, 11 Mar 2025 14:47:13 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E6=96=B9=E5=BC=8F=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?import=5Ftransformer=E5=87=BD=E6=95=B0=E6=9B=BF=E4=BB=A3?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5=E5=AF=BC=E5=85=A5transformers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_web_kit/model/html_classify/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/llm_web_kit/model/html_classify/model.py b/llm_web_kit/model/html_classify/model.py index 76ac4724..cdb1d637 100644 --- a/llm_web_kit/model/html_classify/model.py +++ b/llm_web_kit/model/html_classify/model.py @@ -1,5 +1,7 @@ import torch +from llm_web_kit.model.resource_utils import import_transformer + class Markuplm(): def __init__(self, path, device): @@ -16,13 +18,13 @@ def __init__(self, path, device): self.tokenizer = self.load_tokenizer() def load_tokenizer(self): - from transformers import MarkupLMProcessor + transformers = import_transformer() - return MarkupLMProcessor.from_pretrained(self.model_path) + return transformers.MarkupLMProcessor.from_pretrained(self.model_path) def load_model(self): - from transformers import MarkupLMForSequenceClassification - model = MarkupLMForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels) + transformers = import_transformer() + model = transformers.MarkupLMForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels) # load checkpoint model.load_state_dict(torch.load(self.checkpoint_path, map_location=self.device)) model.to(self.device) From 3d27885884e807ab34e57683873a96c5f243da51 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Tue, 11 Mar 2025 15:46:41 +0800 Subject: [PATCH 3/3] =?UTF-8?q?fix:=20=E7=A1=AE=E4=BF=9D=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E7=9B=AE=E5=BD=95=E5=92=8C=E4=B8=B4=E6=97=B6=E7=9B=AE=E5=BD=95?= =?UTF-8?q?=E5=AD=98=E5=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_web_kit/model/resource_utils/download_assets.py | 1 + llm_web_kit/model/resource_utils/unzip_ext.py | 3 +++ llm_web_kit/model/resource_utils/utils.py | 8 +------- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/llm_web_kit/model/resource_utils/download_assets.py b/llm_web_kit/model/resource_utils/download_assets.py index 9ea85f95..a0fe0ca7 100644 --- a/llm_web_kit/model/resource_utils/download_assets.py +++ b/llm_web_kit/model/resource_utils/download_assets.py @@ -201,6 +201,7 @@ def download_auto_file_core( progress = tqdm(total=total_size, unit='iB', unit_scale=True) # 使用临时目录确保原子性 + os.makedirs(CACHE_TMP_DIR, exist_ok=True) with tempfile.TemporaryDirectory(dir=CACHE_TMP_DIR) as temp_dir: download_path = os.path.join(temp_dir, 'download_file') try: diff --git a/llm_web_kit/model/resource_utils/unzip_ext.py b/llm_web_kit/model/resource_utils/unzip_ext.py index 66622595..6a4a8575 100644 --- a/llm_web_kit/model/resource_utils/unzip_ext.py +++ b/llm_web_kit/model/resource_utils/unzip_ext.py @@ -83,6 +83,9 @@ def unzip_local_file_core( if os.path.exists(target_dir): raise ModelResourceException(f'Target directory {target_dir} already exists') + # make sure the parent directory exists + os.makedirs(os.path.dirname(target_dir), exist_ok=True) + with zipfile.ZipFile(zip_path, 'r') as zip_ref: if password: zip_ref.setpassword(password.encode()) diff --git a/llm_web_kit/model/resource_utils/utils.py b/llm_web_kit/model/resource_utils/utils.py index 63595eee..42dacbf5 100644 --- a/llm_web_kit/model/resource_utils/utils.py +++ b/llm_web_kit/model/resource_utils/utils.py @@ -32,13 +32,6 @@ def decide_cache_dir(): CACHE_DIR, CACHE_TMP_DIR = decide_cache_dir() -if not os.path.exists(CACHE_DIR): - os.makedirs(CACHE_DIR, exist_ok=True) - -if not os.path.exists(CACHE_TMP_DIR): - os.makedirs(CACHE_TMP_DIR, exist_ok=True) - - def try_remove(path: str): """Attempt to remove a file by os.remove or to remove a directory by shutil.rmtree and ignore exceptions.""" @@ -53,6 +46,7 @@ def try_remove(path: str): def import_transformer(): os.environ['HF_HOME'] = CACHE_DIR + os.makedirs(CACHE_DIR, exist_ok=True) import transformers return transformers