10 changes: 6 additions & 4 deletions llm_web_kit/model/html_classify/model.py
@@ -1,5 +1,7 @@
 import torch
 
+from llm_web_kit.model.resource_utils import import_transformer
+
 
 class Markuplm():
     def __init__(self, path, device):
@@ -16,13 +18,13 @@ def __init__(self, path, device):
         self.tokenizer = self.load_tokenizer()
 
     def load_tokenizer(self):
-        from transformers import MarkupLMProcessor
+        transformers = import_transformer()
 
-        return MarkupLMProcessor.from_pretrained(self.model_path)
+        return transformers.MarkupLMProcessor.from_pretrained(self.model_path)
 
     def load_model(self):
-        from transformers import MarkupLMForSequenceClassification
-        model = MarkupLMForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
+        transformers = import_transformer()
+        model = transformers.MarkupLMForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
         # load checkpoint
         model.load_state_dict(torch.load(self.checkpoint_path, map_location=self.device))
         model.to(self.device)
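Note on the pattern above: the module-level `from transformers import ...` is replaced by a call-time `import_transformer()`, so `transformers` is only imported after `HF_HOME` has been pointed at the project cache. A minimal self-contained sketch of the idea, with `CACHE_DIR` hard-coded here purely for illustration (in the real code it comes from `llm_web_kit.model.resource_utils`):

import os

def import_transformer():
    # Illustrative stand-in for the helper added in resource_utils/utils.py.
    CACHE_DIR = os.path.expanduser('~/.cache/llm_web_kit')  # hypothetical value
    os.environ['HF_HOME'] = CACHE_DIR   # must happen before transformers is imported
    os.makedirs(CACHE_DIR, exist_ok=True)
    import transformers
    return transformers

class Markuplm():
    def __init__(self, path, device):
        self.model_path = path
        self.device = device

    def load_tokenizer(self):
        transformers = import_transformer()  # deferred, not a module-level import
        return transformers.MarkupLMProcessor.from_pretrained(self.model_path)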
6 changes: 3 additions & 3 deletions llm_web_kit/model/policical.py
@@ -9,6 +9,7 @@
 from llm_web_kit.libs.logger import mylogger as logger
 from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file,
                                               get_unzip_dir,
+                                              import_transformer,
                                               singleton_resource_manager,
                                               unzip_local_file)
 
@@ -18,16 +19,15 @@ class PoliticalDetector:
     def __init__(self, model_path: str = None):
         # import AutoTokenizer here to avoid isort error
         # must set the HF_HOME to the CACHE_DIR at this point
-        os.environ['HF_HOME'] = CACHE_DIR
-        from transformers import AutoTokenizer
+        transformer = import_transformer()
 
         if not model_path:
             model_path = self.auto_download()
         model_bin_path = os.path.join(model_path, 'model.bin')
         tokenizer_path = os.path.join(model_path, 'internlm2-chat-20b')
 
         self.model = fasttext.load_model(model_bin_path)
-        self.tokenizer = AutoTokenizer.from_pretrained(
+        self.tokenizer = transformer.AutoTokenizer.from_pretrained(
             tokenizer_path, use_fast=False, trust_remote_code=True
         )
 
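The ordering constraint this file relied on (set `HF_HOME`, then import) is exactly what `import_transformer()` now centralizes: `transformers` and `huggingface_hub` resolve their cache directories when they are first imported, so exporting the variable afterwards has no effect. A short sketch of the constraint; the cache path is a hypothetical stand-in:

import os

os.environ['HF_HOME'] = '/data/hf_cache'  # hypothetical path; must be set first

import transformers  # noqa: E402  # cache locations are resolved at import time

# The default cache now lives under /data/hf_cache. Importing first and
# setting HF_HOME second would have left the stock ~/.cache location in place.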
12 changes: 8 additions & 4 deletions llm_web_kit/model/porn_detector.py
@@ -3,19 +3,23 @@
 from typing import Dict, List, Union
 
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from llm_web_kit.config.cfg_reader import load_config
 from llm_web_kit.libs.logger import mylogger as logger
 from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file,
-                                              get_unzip_dir, unzip_local_file)
+                                              get_unzip_dir,
+                                              import_transformer,
+                                              unzip_local_file)
 
+# from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
 
 class BertModel:
     def __init__(self, model_path: str = None) -> None:
         if not model_path:
             model_path = self.auto_download()
-        self.model = AutoModelForSequenceClassification.from_pretrained(
+        transformers_module = import_transformer()
+        self.model = transformers_module.AutoModelForSequenceClassification.from_pretrained(
             os.path.join(model_path, 'porn_classifier/classifier_hf')
         )
         with open(
@@ -37,7 +41,7 @@ def __init__(self, model_path: str = None) -> None:
         if hasattr(self.model, 'to_bettertransformer'):
             self.model = self.model.to_bettertransformer()
 
-        self.tokenizer = AutoTokenizer.from_pretrained(
+        self.tokenizer = transformers_module.AutoTokenizer.from_pretrained(
             os.path.join(model_path, 'porn_classifier/classifier_hf')
         )
         self.tokenizer_config = {
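A side note on the unchanged context: the `hasattr` guard around `to_bettertransformer()` keeps model loading functional on installs where the optimization is unavailable. A hedged sketch of that guard as a standalone helper; the try/except is an extra defensive assumption beyond the diff, since some transformers versions expose the method but raise without the optional `optimum` package:

def maybe_to_bettertransformer(model):
    # Convert only when the method exists; fall back to eager execution otherwise.
    if hasattr(model, 'to_bettertransformer'):
        try:
            model = model.to_bettertransformer()
        except Exception:
            pass  # assumption: plain execution is an acceptable fallback
    return model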
12 changes: 10 additions & 2 deletions llm_web_kit/model/resource_utils/__init__.py
@@ -1,6 +1,14 @@
 from .download_assets import download_auto_file
 from .singleton_resource_manager import singleton_resource_manager
 from .unzip_ext import get_unzip_dir, unzip_local_file
-from .utils import CACHE_DIR, CACHE_TMP_DIR
+from .utils import CACHE_DIR, CACHE_TMP_DIR, import_transformer
 
-__all__ = ['download_auto_file', 'unzip_local_file', 'get_unzip_dir', 'CACHE_DIR', 'CACHE_TMP_DIR', 'singleton_resource_manager']
+__all__ = [
+    'download_auto_file',
+    'unzip_local_file',
+    'get_unzip_dir',
+    'CACHE_DIR',
+    'CACHE_TMP_DIR',
+    'singleton_resource_manager',
+    'import_transformer',
+]
1 change: 1 addition & 0 deletions llm_web_kit/model/resource_utils/download_assets.py
@@ -201,6 +201,7 @@ def download_auto_file_core(
     progress = tqdm(total=total_size, unit='iB', unit_scale=True)
 
     # use a temporary directory to guarantee atomicity
+    os.makedirs(CACHE_TMP_DIR, exist_ok=True)
     with tempfile.TemporaryDirectory(dir=CACHE_TMP_DIR) as temp_dir:
         download_path = os.path.join(temp_dir, 'download_file')
         try:
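The added `os.makedirs` matters because `tempfile.TemporaryDirectory(dir=...)` raises `FileNotFoundError` when the `dir` argument does not exist; it never creates missing parents. A minimal sketch of the surrounding download-then-publish pattern; `fetch` and the cache path are hypothetical stand-ins, and only the `makedirs` line is taken from the diff:

import os
import shutil
import tempfile

CACHE_TMP_DIR = '/tmp/llm_web_kit_cache'  # hypothetical location

def download_atomically(fetch, final_path):
    os.makedirs(CACHE_TMP_DIR, exist_ok=True)  # the fix: pre-create the parent
    with tempfile.TemporaryDirectory(dir=CACHE_TMP_DIR) as temp_dir:
        download_path = os.path.join(temp_dir, 'download_file')
        fetch(download_path)                    # write the payload to the temp file
        shutil.move(download_path, final_path)  # publish only once complete
    return final_path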
3 changes: 3 additions & 0 deletions llm_web_kit/model/resource_utils/unzip_ext.py
@@ -83,6 +83,9 @@ def unzip_local_file_core(
     if os.path.exists(target_dir):
         raise ModelResourceException(f'Target directory {target_dir} already exists')
 
+    # make sure the parent directory exists
+    os.makedirs(os.path.dirname(target_dir), exist_ok=True)
+
     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
         if password:
             zip_ref.setpassword(password.encode())
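One caveat on the fix above: `os.path.dirname` returns an empty string for a bare one-component relative path, and `os.makedirs('')` raises `FileNotFoundError`, so the line assumes `target_dir` is always a nested path — which holds here, since targets are built under `CACHE_DIR`. Illustrated:

import os

# Nested path: the parent can be pre-created.
assert os.path.dirname('/cache/models/target') == '/cache/models'

# Bare name: no parent component, and os.makedirs('') would raise.
assert os.path.dirname('target') == ''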
15 changes: 8 additions & 7 deletions llm_web_kit/model/resource_utils/utils.py
@@ -32,13 +32,6 @@ def decide_cache_dir():
 CACHE_DIR, CACHE_TMP_DIR = decide_cache_dir()
 
 
-if not os.path.exists(CACHE_DIR):
-    os.makedirs(CACHE_DIR, exist_ok=True)
-
-if not os.path.exists(CACHE_TMP_DIR):
-    os.makedirs(CACHE_TMP_DIR, exist_ok=True)
-
-
 def try_remove(path: str):
     """Attempt to remove a file by os.remove or to remove a directory by
     shutil.rmtree and ignore exceptions."""
@@ -49,3 +42,11 @@
         os.remove(path)
     except Exception:
         pass
+
+
+def import_transformer():
+    os.environ['HF_HOME'] = CACHE_DIR
+    os.makedirs(CACHE_DIR, exist_ok=True)
+    import transformers
+
+    return transformers
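With `import_transformer` defined this way, calling it at every load site is cheap: after the first call, `import transformers` is a lookup in `sys.modules`, so repeated calls only redo the inexpensive `os.environ`/`makedirs` pair. A usage sketch:

import sys

from llm_web_kit.model.resource_utils import import_transformer

transformers = import_transformer()        # first call pays the import cost
transformers_again = import_transformer()  # later calls hit sys.modules

assert transformers is transformers_again
assert 'transformers' in sys.modules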
22 changes: 15 additions & 7 deletions tests/llm_web_kit/model/test_porn_detector.py
@@ -1,20 +1,28 @@
+import logging
 import os
 import sys
 import unittest
 from unittest import TestCase
 from unittest.mock import MagicMock, mock_open, patch
 
+from transformers import logging as transformers_logging
+
-from llm_web_kit.model.porn_detector import BertModel  # noqa: E402
-
 current_file_path = os.path.abspath(__file__)
 parent_dir_path = os.path.join(current_file_path, *[os.pardir] * 4)
 normalized_path = os.path.normpath(parent_dir_path)
 sys.path.append(normalized_path)
 
+from llm_web_kit.model.porn_detector import BertModel  # noqa: E402
+
+transformers_logging.set_verbosity_error()
+
+logging.disable(logging.CRITICAL)
+
 
 class TestBertModel(TestCase):
-    @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained')
-    @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained')
+    @patch('transformers.AutoModelForSequenceClassification.from_pretrained')
+    @patch('transformers.AutoTokenizer.from_pretrained')
     @patch('llm_web_kit.model.porn_detector.os.path.join')
     @patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}')
     @patch('llm_web_kit.model.porn_detector.BertModel.auto_download')
@@ -52,8 +60,8 @@ def test_init(self, mock_auto_download, mock_open, mock_os_path_join, mock_from_
             }
         )
 
-    @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained')
-    @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained')
+    @patch('transformers.AutoModelForSequenceClassification.from_pretrained')
+    @patch('transformers.AutoTokenizer.from_pretrained')
     @patch('llm_web_kit.model.porn_detector.os.path.join')
     @patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}')
     @patch('llm_web_kit.model.porn_detector.BertModel.auto_download')
@@ -107,8 +115,8 @@ def test_pre_process(self, mock_torch, mock_auto_download, mock_open, mock_os_pa
         self.assertEqual(result['inputs']['input_ids'], expected_input_ids)
         self.assertEqual(result['inputs']['attention_mask'], expected_attn_mask)
 
-    @patch('llm_web_kit.model.porn_detector.AutoModelForSequenceClassification.from_pretrained')
-    @patch('llm_web_kit.model.porn_detector.AutoTokenizer.from_pretrained')
+    @patch('transformers.AutoModelForSequenceClassification.from_pretrained')
+    @patch('transformers.AutoTokenizer.from_pretrained')
     @patch('llm_web_kit.model.porn_detector.os.path.join')
     @patch('llm_web_kit.model.porn_detector.open', new_callable=mock_open, read_data='{"cls_index": 0, "use_sigmoid": true, "max_tokens": 512, "device": "cuda"}')
     @patch('llm_web_kit.model.porn_detector.BertModel.auto_download')
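The retargeted decorators follow from the lazy import: `porn_detector` no longer binds `AutoTokenizer` as a module attribute at import time, so patching `llm_web_kit.model.porn_detector.AutoTokenizer` would have nothing to replace; the mock must sit where the name is now looked up, on the `transformers` module itself. A condensed sketch of the same situation, with a hypothetical model id:

from unittest.mock import MagicMock, patch

def load_tokenizer():
    import transformers  # resolved lazily, at call time
    return transformers.AutoTokenizer.from_pretrained('org/some-model')

# Patching the attribute on the transformers module works because the lazy
# import re-reads transformers.AutoTokenizer on every call.
with patch('transformers.AutoTokenizer.from_pretrained') as mock_fp:
    mock_fp.return_value = MagicMock()
    tokenizer = load_tokenizer()
    mock_fp.assert_called_once_with('org/some-model')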