diff --git a/app/config.py b/app/config.py index 854123f..efef7ef 100644 --- a/app/config.py +++ b/app/config.py @@ -34,6 +34,7 @@ class Settings(BaseSettings): # type: ignore TRAINING_METRICS_LOGGING_INTERVAL: int = 5 # the number of steps after which training metrics will be collected TRAINING_SAFE_MODEL_SERIALISATION: str = "false" # if "true", serialise the trained model using safe tensors TRAINING_CACHE_DIR: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cms_cache") # the directory to cache the intermediate files created during training + TRAINING_HF_TAGGING_SCHEME: str = "flat" # the tagging scheme during the Hugging Face NER model training, either "flat", "iob" or "iobes" HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scrapper. Switch this on with caution due to the potentially high number of concepts MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to diff --git a/app/domain.py b/app/domain.py index 6be1564..8362499 100644 --- a/app/domain.py +++ b/app/domain.py @@ -77,6 +77,12 @@ class Device(str, Enum): MPS = "mps" +class TaggingScheme(str, Enum): + FLAT = "flat" + IOB = "iob" + IOBES = "iobes" + + class HfTransformerBackbone(Enum): ALBERT = "albert" BIG_BIRD = "bert" @@ -110,20 +116,24 @@ class LlmEngine(Enum): CMS = "CMS" VLLM = "vLLM" + class LlmRole(Enum): SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" + class LlmTrainerType(Enum): GRPO = "grpo" PPO = "ppo" + class LlmDatasetType(Enum): JSON = "json" CSV = "csv" + class Annotation(BaseModel): doc_name: Optional[str] = Field(default=None, description="The name of the document to which the annotation belongs") start: int = Field(description="The start index of the annotation span") diff --git a/app/envs/.env b/app/envs/.env index 3a3a12b..abcbc55 100644 --- a/app/envs/.env +++ b/app/envs/.env @@ -73,6 +73,9 @@ TRAINING_SAFE_MODEL_SERIALISATION=false # The strategy used for aggregating the predictions of the Hugging Face NER model HF_PIPELINE_AGGREGATION_STRATEGY=simple +# The tagging scheme during the Hugging Face NER model training, either "flat", "iob" or "iobes" +TRAINING_HF_TAGGING_SCHEME=flat + # The comma-separated names of ontologies for MedCAT2 to map to MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10 diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index a747739..df128c7 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -16,7 +16,7 @@ from app.exception import ConfigurationException from app.model_services.base import AbstractModelService from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer -from app.domain import ModelCard, ModelType, Annotation +from app.domain import ModelCard, ModelType, Annotation, Device from app.config import Settings from app.utils import ( get_settings, @@ -157,9 +157,19 @@ def load_model( bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) - model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config) + if get_settings().DEVICE == Device.DEFAULT.value: + model = AutoModelForCausalLM.from_pretrained( + model_path, + quantization_config=bnb_config, + device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, 
quantization_config=bnb_config) else: - model = AutoModelForCausalLM.from_pretrained(model_path) + if get_settings().DEVICE == Device.DEFAULT.value: + model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") + else: + model = AutoModelForCausalLM.from_pretrained(model_path) ensure_tensor_contiguity(model) tokenizer = AutoTokenizer.from_pretrained( model_path, @@ -242,8 +252,7 @@ def generate( self.model.eval() inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt") - if non_default_device_is_available(self._config.DEVICE): - inputs.to(get_settings().DEVICE) + inputs.to(self.model.device) generation_kwargs = dict( inputs=inputs.input_ids, @@ -291,8 +300,7 @@ async def generate_async( self.model.eval() inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt") - if non_default_device_is_available(self._config.DEVICE): - inputs.to(get_settings().DEVICE) + inputs.to(self.model.device) streamer = TextIteratorStreamer( self.tokenizer, @@ -363,8 +371,7 @@ def create_embeddings( truncation=True, ) - if non_default_device_is_available(self._config.DEVICE): - inputs.to(get_settings().DEVICE) + inputs.to(self.model.device) with torch.no_grad(): outputs = self.model(**inputs, output_hidden_states=True) diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index e741705..98e55f9 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -16,7 +16,7 @@ from app.exception import ConfigurationException from app.model_services.base import AbstractModelService from app.trainers.huggingface_ner_trainer import HuggingFaceNerUnsupervisedTrainer, HuggingFaceNerSupervisedTrainer -from app.domain import ModelCard, ModelType, Annotation +from app.domain import ModelCard, ModelType, Annotation, Device, TaggingScheme from app.config import Settings from app.utils import ( get_settings, @@ -27,6 +27,7 @@ get_model_data_package_base_name, load_pydantic_object_from_dict, ) +from app.processors.tagging import TagProcessor logger = logging.getLogger("cms") @@ -41,7 +42,7 @@ def __init__( enable_trainer: Optional[bool] = None, model_name: Optional[str] = None, base_model_file: Optional[str] = None, - confidence_threshold: float = 0.5, + confidence_threshold: float = 0.7, ) -> None: """ Initialises the HuggingFace NER model service with specified configurations. @@ -52,7 +53,7 @@ def __init__( enable_trainer (Optional[bool]): The flag to enable or disable trainers. Defaults to None. model_name (Optional[str]): The name of the model. Defaults to None. base_model_file (Optional[str]): The model package file name. Defaults to None. - confidence_threshold (float): The threshold for the confidence score. Defaults to 0.5. + confidence_threshold (float): The threshold for the confidence score. Defaults to 0.7. """ super().__init__(config) @@ -123,7 +124,8 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase) HuggingFaceNerModel: A HuggingFace NER model service. 
""" - model_service = cls(get_settings(), enable_trainer=False) + _config = get_settings() + model_service = cls(_config, enable_trainer=False) model_service.model = model model_service.tokenizer = tokenizer _pipeline = partial( @@ -131,11 +133,11 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase) task="ner", model=model_service.model, tokenizer=model_service.tokenizer, - stride=10, - aggregation_strategy=get_settings().HF_PIPELINE_AGGREGATION_STRATEGY, + stride=32, + aggregation_strategy=_config.HF_PIPELINE_AGGREGATION_STRATEGY, ) - if non_default_device_is_available(get_settings().DEVICE): - model_service._ner_pipeline = _pipeline(device=get_hf_pipeline_device_id(get_settings().DEVICE)) + if non_default_device_is_available(_config.DEVICE): + model_service._ner_pipeline = _pipeline(device=get_hf_pipeline_device_id(_config.DEVICE)) else: model_service._ner_pipeline = _pipeline() return model_service @@ -160,7 +162,10 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path)) if unpack_model_data_package(model_file_path, model_path): try: - model = AutoModelForTokenClassification.from_pretrained(model_path) + if get_settings().DEVICE == Device.DEFAULT.value: + model = AutoModelForTokenClassification.from_pretrained(model_path, device_map="auto") + else: + model = AutoModelForTokenClassification.from_pretrained(model_path) ensure_tensor_contiguity(model) tokenizer = AutoTokenizer.from_pretrained( model_path, @@ -197,7 +202,7 @@ def init_model(self, *args: Any, **kwargs: Any) -> None: task="ner", model=self._model, tokenizer=self._tokenizer, - stride=10, + stride=32, aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY, ) if non_default_device_is_available(get_settings().DEVICE): @@ -233,12 +238,29 @@ def annotate(self, text: str) -> List[Annotation]: List[Annotation]: A list of annotations containing the extracted named entities. 
""" - entities = self._ner_pipeline(text) + if TaggingScheme(self._config.TRAINING_HF_TAGGING_SCHEME.lower()) == TaggingScheme.IOBES: + entities = self._ner_pipeline(text, aggregation_strategy="none") + else: + entities = self._ner_pipeline(text) df = pd.DataFrame(entities) if df.empty: columns = ["label_name", "label_id", "start", "end", "accuracy"] df = pd.DataFrame(columns=(columns + ["text"]) if self._config.INCLUDE_SPAN_TEXT == "true" else columns) + elif TaggingScheme(self._config.TRAINING_HF_TAGGING_SCHEME.lower()) == TaggingScheme.IOBES: + aggregated_entities = TagProcessor.aggregate_bioes_predictions( + df, + text, + self._config.INCLUDE_SPAN_TEXT == "true", + ) + df = pd.DataFrame(aggregated_entities) + if df.empty: + columns = ["label_name", "label_id", "start", "end", "accuracy"] + df = pd.DataFrame( + columns=(columns + ["text"]) if self._config.INCLUDE_SPAN_TEXT == "true" else columns + ) + else: + df = df[df["accuracy"] >= self._confidence_threshold] else: for idx, row in df.iterrows(): df.loc[idx, "label_id"] = row["entity_group"] diff --git a/app/processors/tagging.py b/app/processors/tagging.py new file mode 100644 index 0000000..3b45bc4 --- /dev/null +++ b/app/processors/tagging.py @@ -0,0 +1,402 @@ +import pandas as pd +from typing import Any, Dict, List, Optional, Iterable +from torch import nn, mean, cat +from transformers import PreTrainedModel +from app.domain import TaggingScheme + + +class TagProcessor: + + @staticmethod + def update_model_by_tagging_scheme( + model: PreTrainedModel, + concepts: List[str], + tagging_scheme: TaggingScheme, + ) -> PreTrainedModel: + avg_weight = mean(model.classifier.weight, dim=0, keepdim=True) + avg_bias = mean(model.classifier.bias, dim=0, keepdim=True) + if tagging_scheme == TaggingScheme.IOB: + for concept in concepts: + b_label = f"B-{concept}" + i_label = f"I-{concept}" + if b_label not in model.config.label2id.keys(): + model.config.label2id[b_label] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = b_label + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + if i_label not in model.config.label2id.keys(): + model.config.label2id[i_label] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = i_label + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + elif tagging_scheme == TaggingScheme.IOBES: + for concept in concepts: + s_label = f"S-{concept}" + b_label = f"B-{concept}" + i_label = f"I-{concept}" + e_label = f"E-{concept}" + if s_label not in model.config.label2id.keys(): + model.config.label2id[s_label] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = s_label + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + if b_label not in model.config.label2id.keys(): + model.config.label2id[b_label] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = b_label + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = 
nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + if i_label not in model.config.label2id.keys(): + model.config.label2id[i_label] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = i_label + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + if e_label not in model.config.label2id.keys(): + model.config.label2id[e_label] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = e_label + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + else: + for concept in concepts: + if concept not in model.config.label2id.keys(): + model.config.label2id[concept] = len(model.config.label2id) + model.config.id2label[len(model.config.id2label)] = concept + model.classifier.weight = nn.Parameter(cat((model.classifier.weight, avg_weight), 0)) + model.classifier.bias = nn.Parameter(cat((model.classifier.bias, avg_bias), 0)) + model.classifier.out_features += 1 + model.num_labels += 1 + return model + + @staticmethod + def generate_chuncks_by_tagging_scheme( + annotations: List[Dict], + tokenized: Dict[str, List], + delfault_label_id: int, + pad_token_id: int, + pad_label_id: int, + max_length: int, + model: PreTrainedModel, + tagging_scheme: TaggingScheme, + window_size: int, + stride: int, + ) -> Iterable[Dict[str, Any]]: + if tagging_scheme == TaggingScheme.IOB: + labels = [delfault_label_id] * len(tokenized["input_ids"]) + for annotation in annotations: + start = annotation["start"] + end = annotation["end"] + cui = annotation["cui"] + b_label = f"B-{cui}" + i_label = f"I-{cui}" + b_label_id = model.config.label2id.get(b_label, delfault_label_id) + i_label_id = model.config.label2id.get(i_label, delfault_label_id) + first_token = True + for idx, offset_mapping in enumerate(tokenized["offset_mapping"]): + if start <= offset_mapping[0] and offset_mapping[1] <= end: + if first_token: + labels[idx] = b_label_id + first_token = False + else: + labels[idx] = i_label_id + + for start in range(0, len(tokenized["input_ids"]), stride): + end = min(start + window_size, len(tokenized["input_ids"])) + chunked_input_ids = tokenized["input_ids"][start:end] + chunked_labels = labels[start:end] + chunked_attention_mask = tokenized["attention_mask"][start:end] + padding_length = max(0, max_length - len(chunked_input_ids)) + chunked_input_ids += [pad_token_id] * padding_length + chunked_labels += [pad_label_id] * padding_length + chunked_attention_mask += [0] * padding_length + + yield { + "input_ids": chunked_input_ids, + "labels": chunked_labels, + "attention_mask": chunked_attention_mask, + } + + elif tagging_scheme == TaggingScheme.IOBES: + labels = [delfault_label_id] * len(tokenized["input_ids"]) + for annotation in annotations: + ann_start = annotation["start"] + ann_end = annotation["end"] + cui = annotation["cui"] + + covered_indices = [ + idx for idx, off in enumerate(tokenized["offset_mapping"]) + if ann_start <= off[0] and off[1] <= ann_end + ] + if not covered_indices: + continue + + if len(covered_indices) == 1: + s_label = f"S-{cui}" + s_id = model.config.label2id.get(s_label, delfault_label_id) + labels[covered_indices[0]] = 
s_id + else: + b_label = f"B-{cui}" + i_label = f"I-{cui}" + e_label = f"E-{cui}" + b_id = model.config.label2id.get(b_label, delfault_label_id) + i_id = model.config.label2id.get(i_label, delfault_label_id) + e_id = model.config.label2id.get(e_label, delfault_label_id) + + labels[covered_indices[0]] = b_id + for mid_idx in covered_indices[1:-1]: + labels[mid_idx] = i_id + labels[covered_indices[-1]] = e_id + + for start in range(0, len(tokenized["input_ids"]), stride): + end = min(start + window_size, len(tokenized["input_ids"])) + chunked_input_ids = tokenized["input_ids"][start:end] + chunked_labels = labels[start:end] + chunked_attention_mask = tokenized["attention_mask"][start:end] + padding_length = max(0, max_length - len(chunked_input_ids)) + chunked_input_ids += [pad_token_id] * padding_length + chunked_labels += [pad_label_id] * padding_length + chunked_attention_mask += [0] * padding_length + + yield { + "input_ids": chunked_input_ids, + "labels": chunked_labels, + "attention_mask": chunked_attention_mask, + } + else: + for start in range(0, len(tokenized["input_ids"]), stride): + end = min(start + window_size, len(tokenized["input_ids"])) + chunked_input_ids = tokenized["input_ids"][start:end] + chunked_offsets_mapping = tokenized["offset_mapping"][start:end] + chunked_labels = [0] * len(chunked_input_ids) + chunked_attention_mask = tokenized["attention_mask"][start:end] + for annotation in annotations: + annotation_start = annotation["start"] + annotation_end = annotation["end"] + label_id = model.config.label2id.get(annotation["cui"], delfault_label_id) + for idx, offset_mapping in enumerate(chunked_offsets_mapping): + if annotation_start <= offset_mapping[0] and offset_mapping[1] <= annotation_end: + chunked_labels[idx] = label_id + padding_length = max(0, max_length - len(chunked_input_ids)) + chunked_input_ids += [pad_token_id] * padding_length + chunked_labels += [pad_label_id] * padding_length + chunked_attention_mask += [0] * padding_length + + yield { + "input_ids": chunked_input_ids, + "labels": chunked_labels, + "attention_mask": chunked_attention_mask, + } + + @staticmethod + def aggregate_bioes_predictions( + df: pd.DataFrame, + text: str, + include_span_text: bool = False, + ) -> List[Dict[str, Any]]: + aggregated_entities = [] + current_entity = None + current_label = None + current_score = 0.0 + token_count = 0 + + for _, row in df.iterrows(): + entity_tag = str(row.get("entity", "")).strip() + score = float(row.get("score", 0.0)) + start = int(row.get("start", 0)) + end = int(row.get("end", 0)) + + if entity_tag.upper() == "O" or entity_tag == "": + if current_entity is not None: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + current_entity = None + current_label = None + current_score = 0.0 + token_count = 0 + continue + + if "-" in entity_tag: + prefix, label = entity_tag.split("-", 1) + prefix = prefix.upper() + else: + prefix = None + label = entity_tag + + if prefix == "B": + if current_entity is not None: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + current_label = label + current_entity = {"start": start, "end": end} + current_score = score + token_count = 1 + + elif prefix == "I": + if current_entity is None: + current_label = label + current_entity = {"start": start, "end": end} + current_score = 
score + token_count = 1 + else: + if label == current_label: + current_entity["end"] = end + current_score += score + token_count += 1 + else: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + current_label = label + current_entity = {"start": start, "end": end} + current_score = score + token_count = 1 + + elif prefix == "E": + if current_entity is None: + single_ent = {"start": start, "end": end} + aggregated_entities.append( + TagProcessor._get_composed_entitiy(text, single_ent, label, score, 1, include_span_text) + ) + else: + if label == current_label: + current_entity["end"] = end + current_score += score + token_count += 1 + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + current_entity = None + current_label = None + current_score = 0.0 + token_count = 0 + else: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + single_ent = {"start": start, "end": end} + aggregated_entities.append( + TagProcessor._get_composed_entitiy(text, single_ent, label, score, 1, include_span_text) + ) + current_entity = None + current_label = None + current_score = 0.0 + token_count = 0 + + elif prefix == "S" or prefix is None: + if current_entity is not None: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + current_entity = None + current_label = None + current_score = 0.0 + token_count = 0 + single_ent = {"start": start, "end": end} + aggregated_entities.append( + TagProcessor._get_composed_entitiy(text, single_ent, label, score, 1, include_span_text) + ) + + else: + if current_entity is not None: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + current_entity = None + current_label = None + current_score = 0.0 + token_count = 0 + + if current_entity is not None: + aggregated_entities.append( + TagProcessor._get_composed_entitiy( + text, + current_entity, + current_label, + current_score, + token_count, + include_span_text, + ) + ) + + return aggregated_entities + + @staticmethod + def _get_composed_entitiy( + text: str, + entity: Dict, + label: Optional[str], + score: float, + token_count: int, + include_span_text: bool, + ) -> Dict[str, Any]: + return { + "entity_group": label, + "label_name": label, + "label_id": label, + "start": entity["start"], + "end": entity["end"], + "score": score / token_count, + "accuracy": score / token_count, + "text": text[entity["start"]:entity["end"]] if include_span_text else None + } diff --git a/app/trainers/huggingface_ner_trainer.py b/app/trainers/huggingface_ner_trainer.py index c975506..2aa44aa 100644 --- a/app/trainers/huggingface_ner_trainer.py +++ b/app/trainers/huggingface_ner_trainer.py @@ -14,7 +14,9 @@ from typing import final, Dict, TextIO, Optional, Any, List, Iterable, Tuple, Union, cast, TYPE_CHECKING from torch import nn from tqdm import tqdm -from sklearn.metrics import precision_recall_fscore_support, accuracy_score +from sklearn.metrics import precision_recall_fscore_support, accuracy_score as sklearn_accuracy_score +from 
sklearn.utils.class_weight import compute_class_weight +from seqeval.metrics import classification_report, accuracy_score as seqeval_accuracy_score from scipy.special import softmax from transformers import __version__ as transformers_version from transformers import ( @@ -45,7 +47,8 @@ get_model_data_package_base_name, ) from app.trainers.base import UnsupervisedTrainer, SupervisedTrainer -from app.domain import ModelType, DatasetSplit, HfTransformerBackbone, Device, TrainerBackend +from app.domain import ModelType, DatasetSplit, HfTransformerBackbone, Device, TrainerBackend, TaggingScheme +from app.processors.tagging import TagProcessor from app.exception import AnnotationException, TrainingCancelledException, DatasetException if TYPE_CHECKING: from app.model_services.huggingface_ner_model import HuggingFaceNerModel @@ -121,6 +124,8 @@ def run( reset_random_seed() eval_mode = training_params["nepochs"] == 0 + window_size = max(self._max_length - 2, 1) + stride = max(window_size // 2, 1) self._tracker_client.log_trainer_mode(not eval_mode) if not eval_mode: try: @@ -131,7 +136,7 @@ def run( os.path.dirname(copied_model_pack_path), get_model_data_package_base_name(copied_model_pack_path), ) - mlm_model = self._get_mlm_model(model, copied_model_directory) + mlm_model = self._get_mlm_model(model, copied_model_directory, self._config.DEVICE) if non_default_device_is_available(self._config.DEVICE): mlm_model.to(self._config.DEVICE) @@ -165,13 +170,25 @@ def run( train_dataset = datasets.Dataset.from_generator( self._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"texts": train_texts, "tokenizer": tokenizer, "max_length": self._max_length}, + gen_kwargs={ + "texts": train_texts, + "tokenizer": tokenizer, + "max_length": self._max_length, + "window_size": window_size, + "stride": stride, + }, cache_dir=self._model_service._config.TRAINING_CACHE_DIR, ) eval_dataset = datasets.Dataset.from_generator( self._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"texts": eval_texts, "tokenizer": tokenizer, "max_length": self._max_length}, + gen_kwargs={ + "texts": eval_texts, + "tokenizer": tokenizer, + "max_length": self._max_length, + "window_size": window_size, + "stride": stride, + }, cache_dir = self._model_service._config.TRAINING_CACHE_DIR, ) train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) @@ -279,7 +296,11 @@ def run( self._tracker_client.log_model_config(self._model_service._model.config.to_dict()) self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version) - mlm_model = self._get_mlm_model(self._model_service._model, os.path.splitext(self._model_pack_path)[0]) + mlm_model = self._get_mlm_model( + self._model_service._model, + os.path.splitext(self._model_pack_path)[0], + self._config.DEVICE, + ) if non_default_device_is_available(self._config.DEVICE): mlm_model.to(self._config.DEVICE) @@ -315,8 +336,6 @@ def run( self._model_service._model.eval() - window_size = 256 - stride = 128 batch_size = 32 def _create_iterative_masking(input_id: List[int], mask_token: int, pad_token_id: int) -> Tuple[torch.Tensor, torch.Tensor]: @@ -345,8 +364,8 @@ def _create_iterative_masking(input_id: List[int], mask_token: int, pad_token_id for input_ids in batch_input_ids: input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) # Ensure 2D shape - n = input_ids.size(-1) # 512 - n_windows = (n - 1) // stride + 1 # 4 + n = input_ids.size(-1) + n_windows = (n - 1) // stride + 1 input_ids_out = torch.full( (n_windows, 
window_size), @@ -429,8 +448,11 @@ def _create_iterative_masking(input_id: List[int], mask_token: int, pad_token_id @staticmethod - def _get_mlm_model(model: PreTrainedModel, copied_model_directory: str) -> PreTrainedModel: - mlm_model = AutoModelForMaskedLM.from_pretrained(copied_model_directory) + def _get_mlm_model(model: PreTrainedModel, copied_model_directory: str, device: str) -> PreTrainedModel: + if device.lower() == Device.DEFAULT.value: + mlm_model = AutoModelForMaskedLM.from_pretrained(copied_model_directory, device_map="auto") + else: + mlm_model = AutoModelForMaskedLM.from_pretrained(copied_model_directory) ensure_tensor_contiguity(mlm_model) backbone_found = False for backbone in HfTransformerBackbone: @@ -459,6 +481,8 @@ def _tokenize_and_chunk( texts: Iterable[str], tokenizer: PreTrainedTokenizerBase, max_length: int, + window_size: int, + stride: int, ) -> Iterable[Dict[str, Any]]: for text in texts: encoded = tokenizer( @@ -468,14 +492,15 @@ def _tokenize_and_chunk( return_special_tokens_mask=True, ) - for i in range(0, len(encoded["input_ids"]), max_length): - chunked_input_ids = encoded["input_ids"][i:i + max_length] + for start in range(0, len(encoded["input_ids"]), stride): + end = min(start + window_size, len(encoded["input_ids"])) + chunked_input_ids = encoded["input_ids"][start:end] padding_length = max(0, max_length - len(chunked_input_ids)) chunked_input_ids += [tokenizer.pad_token_id] * padding_length - chunked_attention_mask = encoded["attention_mask"][i:i + max_length] + [0] * padding_length + chunked_attention_mask = encoded["attention_mask"][start:end] + [0] * padding_length chunked_special_tokens = tokenizer.get_special_tokens_mask(chunked_input_ids, - already_has_special_tokens=True) + already_has_special_tokens=True) token_type_ids = [0] * len(chunked_input_ids) yield { @@ -559,7 +584,10 @@ def run( logs_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "logs")) reset_random_seed() eval_mode = training_params["nepochs"] == 0 + window_size = max(self._max_length - 2, 1) + stride = max(window_size // 2, 1) self._tracker_client.log_trainer_mode(not eval_mode) + tagging_scheme = TaggingScheme(self._model_service._config.TRAINING_HF_TAGGING_SCHEME.lower()) if not eval_mode: try: logger.info("Loading a new model copy for training...") @@ -575,8 +603,7 @@ def run( filtered_training_data, filtered_concepts = self._filter_training_data_and_concepts(data_file) logger.debug(f"Filtered concepts: {filtered_concepts}") - model = self._update_model_with_concepts(model, filtered_concepts) - + model = self._update_model_with_concepts(model, filtered_concepts, tagging_scheme) test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] if test_size < 0: @@ -598,16 +625,33 @@ def run( "labels": datasets.Sequence(datasets.Value("int32")), "attention_mask": datasets.Sequence(datasets.Value("int32")), }) + train_dataset = datasets.Dataset.from_generator( self._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"documents": train_documents, "tokenizer": tokenizer, "max_length": self._max_length, "model": model}, + gen_kwargs={ + "documents": train_documents, + "tokenizer": tokenizer, + "max_length": self._max_length, + "model": model, + "tagging_scheme": tagging_scheme, + "window_size": window_size, + "stride": stride, + }, cache_dir=self._config.TRAINING_CACHE_DIR, ) eval_dataset = datasets.Dataset.from_generator( self._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"documents": eval_documents, 
"tokenizer": tokenizer, "max_length": self._max_length, "model": model}, + gen_kwargs={ + "documents": eval_documents, + "tokenizer": tokenizer, + "max_length": self._max_length, + "model": model, + "tagging_scheme": tagging_scheme, + "window_size": window_size, + "stride": stride, + }, cache_dir = self._config.TRAINING_CACHE_DIR, ) train_dataset.set_format(type=None, columns=["input_ids", "labels", "attention_mask"]) @@ -625,13 +669,44 @@ def run( if early_stopping_patience > 0: trainer_callbacks.append(EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)) + train_labels = [] + weights = torch.ones(model.num_labels, dtype=torch.float) + for example in train_dataset: + train_labels.extend([label for label in example["labels"] if label != HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID]) + unique_labels = np.unique(train_labels) + class_weight_vect = compute_class_weight("balanced", classes=unique_labels, y=train_labels) + for label_id, weight in zip(unique_labels, class_weight_vect): + weights[label_id] = weight + + if non_default_device_is_available(self._config.DEVICE): + weights = weights.to(self._config.DEVICE) + else: + weights = weights.to(model.device) + + def _compute_loss( + outputs: Dict[str, Any], + labels: torch.Tensor, + num_items_in_batch: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + logits = outputs.get("logits") + loss_func = nn.CrossEntropyLoss(weight=weights, ignore_index=HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID) + loss = loss_func(logits.view(-1, model.num_labels), labels.view(-1)) # type: ignore + return loss + hf_trainer = Trainer( model=model, args=training_args, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, - compute_metrics=partial(self._compute_token_level_metrics, id2label=model.config.id2label, tracker_client=self._tracker_client, model_name=self._model_name), + compute_metrics=partial( + self._compute_metrics, + id2label=model.config.id2label, + tracker_client=self._tracker_client, + model_name=self._model_name, + token_level=True if tagging_scheme == TaggingScheme.FLAT else False, + ), + compute_loss_func=_compute_loss, callbacks=trainer_callbacks, ) @@ -709,19 +784,34 @@ def run( eval_dataset = datasets.Dataset.from_generator( self._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"documents": eval_documents, "tokenizer": self._model_service.tokenizer, "max_length": self._max_length, "model": self._model_service._model}, + gen_kwargs={ + "documents": eval_documents, + "tokenizer": self._model_service.tokenizer, + "max_length": self._max_length, + "model": self._model_service._model, + "tagging_scheme": tagging_scheme, + "window_size": window_size, + "stride": stride, + }, cache_dir=self._config.TRAINING_CACHE_DIR, ) eval_dataset.set_format(type=None, columns=["input_ids", "labels", "attention_mask"]) data_collator = self._LocalDataCollator(max_length=self._max_length, pad_token_id=self._model_service.tokenizer.pad_token_id) training_args = self._get_training_args(results_path, logs_path, training_params, log_frequency) + training_args.eval_strategy = "no" hf_trainer = Trainer( model=self._model_service.model, args=training_args, data_collator=data_collator, train_dataset=None, eval_dataset=None, - compute_metrics=partial(self._compute_token_level_metrics, id2label=self._model_service.model.config.id2label, tracker_client=self._tracker_client, model_name=self._model_name), + compute_metrics=partial( + self._compute_metrics, + id2label=self._model_service.model.config.id2label, + 
tracker_client=self._tracker_client, + model_name=self._model_name, + token_level=False, + ), tokenizer=None, ) eval_metrics = hf_trainer.evaluate(eval_dataset) @@ -756,22 +846,16 @@ def _filter_training_data_and_concepts(data_file: TextIO) -> Tuple[Dict, List]: return filtered_training_data, filtered_concepts @staticmethod - def _update_model_with_concepts(model: PreTrainedModel, concepts: List[str]) -> PreTrainedModel: + def _update_model_with_concepts( + model: PreTrainedModel, + concepts: List[str], + tagging_scheme: TaggingScheme, + ) -> PreTrainedModel: if model.config.label2id == {"LABEL_0": 0, "LABEL_1": 1}: logger.debug("Cannot find existing labels and IDs, creating new ones...") model.config.label2id = {"O": HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID, "X": HuggingFaceNerSupervisedTrainer.CONTINUING_TOKEN_LABEL_ID} model.config.id2label = {HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID: "O", HuggingFaceNerSupervisedTrainer.CONTINUING_TOKEN_LABEL_ID: "X"} - avg_weight = torch.mean(model.classifier.weight, dim=0, keepdim=True) - avg_bias = torch.mean(model.classifier.bias, dim=0, keepdim=True) - for concept in concepts: - if concept not in model.config.label2id.keys(): - model.config.label2id[concept] = len(model.config.label2id) - model.config.id2label[len(model.config.id2label)] = concept - model.classifier.weight = nn.Parameter(torch.cat((model.classifier.weight, avg_weight), 0)) - model.classifier.bias = nn.Parameter(torch.cat((model.classifier.bias, avg_bias), 0)) - model.classifier.out_features += 1 - model.num_labels += 1 - return model + return TagProcessor.update_model_by_tagging_scheme(model, concepts, tagging_scheme) @staticmethod def _tokenize_and_chunk( @@ -779,6 +863,9 @@ def _tokenize_and_chunk( tokenizer: PreTrainedTokenizerBase, max_length: int, model: PreTrainedModel, + tagging_scheme: TaggingScheme, + window_size: int, + stride: int, ) -> Iterable[Dict[str, Any]]: for document in documents: encoded = tokenizer( @@ -788,31 +875,27 @@ def _tokenize_and_chunk( return_offsets_mapping=True, ) document["annotations"] = sorted(document["annotations"], key=lambda annotation: annotation["start"]) - for i in range(0, len(encoded["input_ids"]), max_length): - chunked_input_ids = encoded["input_ids"][i:i + max_length] - chunked_offsets_mapping = encoded["offset_mapping"][i:i + max_length] - chunked_labels = [0] * len(chunked_input_ids) - for annotation in document["annotations"]: - start = annotation["start"] - end = annotation["end"] - label_id = model.config.label2id.get(annotation["cui"], HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID) - for idx, offset_mapping in enumerate(chunked_offsets_mapping): - if start <= offset_mapping[0] and offset_mapping[1] <= end: - chunked_labels[idx] = label_id - chunked_attention_mask = encoded["attention_mask"][i:i + max_length] - yield { - "input_ids": chunked_input_ids, - "labels": chunked_labels, - "attention_mask": chunked_attention_mask, - } + yield from TagProcessor.generate_chuncks_by_tagging_scheme( + annotations=document["annotations"], + tokenized=encoded, + delfault_label_id=HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID, + pad_token_id=tokenizer.pad_token_id, + pad_label_id=HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID, + max_length=max_length, + model=model, + tagging_scheme=tagging_scheme, + window_size=window_size, + stride=stride, + ) @staticmethod - def _compute_token_level_metrics( + def _compute_metrics( eval_pred: EvalPrediction, id2label: Dict[int, str], tracker_client: TrackerClient, model_name: str, + 
token_level: bool, ) -> Dict[str, Any]: predictions = np.argmax(softmax(eval_pred.predictions, axis=2), axis=2) label_ids = eval_pred.label_ids @@ -820,40 +903,91 @@ def _compute_token_level_metrics( non_padding_predictions = predictions[non_padding_indices].flatten() non_padding_label_ids = label_ids[non_padding_indices].flatten() labels = list(id2label.values()) - precision, recall, f1, support = precision_recall_fscore_support(non_padding_label_ids, non_padding_predictions, labels=list(id2label.keys()), average=None) - filtered_predictions, filtered_label_ids = zip(*[(a, b) for a, b in zip(non_padding_predictions, non_padding_label_ids) if not (a == b == HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID)]) - accuracy = accuracy_score(filtered_label_ids, filtered_predictions) - metrics = { - "accuracy": accuracy, - "f1_avg": np.average(f1[2:]), - "precision_avg": np.average(precision[2:]), - "recall_avg": np.average(recall[2:]), - "support_avg": np.average(support[2:]), - } - aggregated_labels = [] - aggregated_metrics = [] - - # limit the number of labels to avoid excessive metrics logging - for idx, (label, precision, recall, f1, support) in enumerate(zip(labels[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - precision[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - recall[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - f1[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - support[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2])): - if support == 0: # the concept has no true labels - continue - metrics[f"{label}/precision"] = precision if precision is not None else 0.0 - metrics[f"{label}/recall"] = recall if recall is not None else 0.0 - metrics[f"{label}/f1"] = f1 if f1 is not None else 0.0 - metrics[f"{label}/support"] = support if support is not None else 0.0 - - aggregated_labels.append(label) - aggregated_metrics.append({ - "per_concept_p": metrics[f"{label}/precision"], - "per_concept_r": metrics[f"{label}/recall"], - "per_concept_f1": metrics[f"{label}/f1"], - }) - - HuggingFaceNerSupervisedTrainer._save_metrics_plot(aggregated_metrics, aggregated_labels, tracker_client, model_name) + + if token_level: + # Get token level metrics + precision, recall, f1, support = precision_recall_fscore_support(non_padding_label_ids, non_padding_predictions, labels=list(id2label.keys()), average=None) + filtered_predictions, filtered_label_ids = zip(*[(a, b) for a, b in zip(non_padding_predictions, non_padding_label_ids) if not (a == b == HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID)]) + accuracy = sklearn_accuracy_score(filtered_label_ids, filtered_predictions) + metrics = { + "accuracy": accuracy, + "f1_avg": np.average(f1[2:]), + "precision_avg": np.average(precision[2:]), + "recall_avg": np.average(recall[2:]), + "support_avg": np.average(support[2:]), + } + aggregated_labels = [] + aggregated_metrics = [] + + # Limit the number of labels to avoid excessive metrics logging + for idx, (label, precision, recall, f1, support) in enumerate(zip(labels[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], + precision[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], + recall[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], + f1[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], + support[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2])): + if support == 0: # The concept has no true labels + continue + metrics[f"{label}/precision"] = precision if precision is not None else 0.0 + 
metrics[f"{label}/recall"] = recall if recall is not None else 0.0 + metrics[f"{label}/f1"] = f1 if f1 is not None else 0.0 + metrics[f"{label}/support"] = support if support is not None else 0.0 + + aggregated_labels.append(label) + aggregated_metrics.append({ + "per_concept_p": metrics[f"{label}/precision"], + "per_concept_r": metrics[f"{label}/recall"], + "per_concept_f1": metrics[f"{label}/f1"], + }) + else: + # Get entity level metrics + y_true = [] + y_pred = [] + for i in range(label_ids.shape[0]): + true_labels = [] + pred_labels = [] + for j in range(label_ids.shape[1]): + if label_ids[i, j] != HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID: + true_labels.append(id2label[label_ids[i, j]]) + pred_labels.append(id2label[predictions[i, j]]) + else: + break + y_true.append(true_labels) + y_pred.append(pred_labels) + report = classification_report(y_true, y_pred, output_dict=True) + accuracy = seqeval_accuracy_score(y_true, y_pred) + metrics = { + "accuracy": accuracy, + "f1_avg": np.mean([report[label]["f1-score"] for label in report]), + "precision_avg": np.mean([report[label]["precision"] for label in report]), + "recall_avg": np.mean([report[label]["recall"] for label in report]), + "support_avg": np.mean([report[label]["support"] for label in report]), + } + aggregated_labels = [] + aggregated_metrics = [] + + # Limit the number of labels to avoid excessive metrics logging + label_keys = [k for k in report.keys() if k not in ['weighted avg', 'macro avg', 'micro avg']] + for _, label in enumerate(label_keys[:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK]): + if label not in report or report[label]['support'] == 0: # The label has no true labels + continue + metrics[f"{label}/precision"] = report[label]["precision"] + metrics[f"{label}/recall"] = report[label]["recall"] + metrics[f"{label}/f1"] = report[label]["f1-score"] + metrics[f"{label}/support"] = report[label]["support"] + + aggregated_labels.append(label) + aggregated_metrics.append({ + "per_concept_p": metrics[f"{label}/precision"], + "per_concept_r": metrics[f"{label}/recall"], + "per_concept_f1": metrics[f"{label}/f1"], + }) + + HuggingFaceNerSupervisedTrainer._save_metrics_plot( + aggregated_metrics, + aggregated_labels, + tracker_client, + model_name, + ) logger.debug("Evaluation metrics: %s", metrics) return metrics diff --git a/app/utils.py b/app/utils.py index 6979f55..9da3e75 100644 --- a/app/utils.py +++ b/app/utils.py @@ -528,6 +528,23 @@ def get_hf_pipeline_device_id(device: str) -> int: return device_id +def get_hf_device_map(device: str) -> Dict: + """ + Retrieves the device map for a Hugging Face model based on the specified device string. + + Args: + device (str): The string representation of the device. + + Returns: + Dict: The device map for the Hugging Face model. + """ + + if device.startswith(Device.GPU.value) or device.startswith(Device.MPS.value): + return {"": device} + else: + return {"": "cpu"} + + def unpack_model_data_package(model_data_file_path: str, model_data_folder_path: str) -> bool: """ Unpacks a model data package from a zip or tar.gz file into the specified folder. 
diff --git a/pyproject.toml b/pyproject.toml index 4c5d3e9..8b9be40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "toml~=0.10.2", "peft<0.14.0", "huggingface-hub~=0.33.0", + "spacy>3.8", + "seqeval>=1.2.2", ] readme = "README.md" keywords = ["natural-language-processing", "electronic-health-records", "clinical-data"] diff --git a/tests/app/conftest.py b/tests/app/conftest.py index 087ed07..1fd6715 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -105,6 +105,7 @@ def huggingface_ner_model(): def huggingface_llm_model(): config = Settings() config.BASE_MODEL_FILE = "huggingface_llm_model.tar.gz" + config.TRAINING_HF_TAGGING_SCHEME = "flat" model_service = HuggingFaceLlmModel(config, MODEL_PARENT_DIR) model_service.init_model() return model_service diff --git a/tests/app/model_services/test_huggingface_ner_model.py b/tests/app/model_services/test_huggingface_ner_model.py index 77eb78b..f617979 100644 --- a/tests/app/model_services/test_huggingface_ner_model.py +++ b/tests/app/model_services/test_huggingface_ner_model.py @@ -1,6 +1,8 @@ import os import tempfile from unittest.mock import Mock +import pandas as pd +import pytest from tests.app.conftest import MODEL_PARENT_DIR from transformers import PreTrainedModel, PreTrainedTokenizerBase from app import __version__ diff --git a/tests/app/processors/test_tagging.py b/tests/app/processors/test_tagging.py new file mode 100644 index 0000000..83ab907 --- /dev/null +++ b/tests/app/processors/test_tagging.py @@ -0,0 +1,470 @@ +import pytest +import pandas as pd +from unittest.mock import MagicMock +from torch import ones, zeros +from app.processors.tagging import TagProcessor +from app.domain import TaggingScheme + + +class TestAgregateBioesPredictions: + + def test_aggregate_bioes_predictions_empty_dataframe(self): + empty_df = pd.DataFrame() + text = "This is a test sentence." 
+ + result = TagProcessor.aggregate_bioes_predictions(empty_df, text, True) + + assert result == [] + + def test_aggregate_bioes_predictions_only_o_tags(self): + df = pd.DataFrame([ + {"entity": "O", "score": 0.9, "start": 0, "end": 4}, + {"entity": "O", "score": 0.8, "start": 5, "end": 7}, + {"entity": "O", "score": 0.7, "start": 8, "end": 12}, + ]) + text = "This is a test" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert result == [] + + def test_aggregate_bioes_predictions_single_token_entities(self): + df = pd.DataFrame([ + {"entity": "S-DISEASE", "score": 0.9, "start": 0, "end": 7}, + {"entity": "O", "score": 0.8, "start": 8, "end": 12}, + {"entity": "S-MEDICATION", "score": 0.7, "start": 12, "end": 20}, + ]) + text = "Disease and medicine" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 2 + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["label_name"] == "DISEASE" + assert result[0]["start"] == 0 + assert result[0]["end"] == 7 + assert result[0]["text"] == "Disease" + assert result[0]["score"] == 0.9 + assert result[0]["accuracy"] == 0.9 + assert result[1]["entity_group"] == "MEDICATION" + assert result[1]["label_name"] == "MEDICATION" + assert result[1]["start"] == 12 + assert result[1]["end"] == 20 + assert result[1]["text"] == "medicine" + assert result[1]["score"] == 0.7 + assert result[1]["accuracy"] == 0.7 + + def test_aggregate_bioes_predictions_multi_token_entities(self): + df = pd.DataFrame([ + {"entity": "B-DISEASE", "score": 0.9, "start": 0, "end": 4}, + {"entity": "I-DISEASE", "score": 0.8, "start": 4, "end": 11}, + {"entity": "E-DISEASE", "score": 0.7, "start": 11, "end": 18}, + {"entity": "O", "score": 0.8, "start": 19, "end": 27}, + ]) + text = "Heart disease and diabetes" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 1 + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["label_name"] == "DISEASE" + assert result[0]["start"] == 0 + assert result[0]["end"] == 18 + assert result[0]["text"] == "Heart disease and " + assert abs(result[0]["score"] - (0.9 + 0.8 + 0.7) / 3) < 1e-6 + assert abs(result[0]["accuracy"] - (0.9 + 0.8 + 0.7) / 3) < 1e-6 + + def test_aggregate_bioes_predictions_beginning_entities(self): + df = pd.DataFrame([ + {"entity": "B-DISEASE", "score": 0.9, "start": 0, "end": 11}, + {"entity": "O", "score": 0.8, "start": 12, "end": 16}, + ]) + text = "Heart disease" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 1 + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["start"] == 0 + assert result[0]["end"] == 11 + assert result[0]["score"] == 0.9 + assert result[0]["text"] == "Heart disea" + + def test_aggregate_bioes_predictions_inside_entities(self): + df = pd.DataFrame([ + {"entity": "I-DISEASE", "score": 0.9, "start": 0, "end": 5}, + {"entity": "I-DISEASE", "score": 0.8, "start": 5, "end": 11}, + {"entity": "O", "score": 0.8, "start": 12, "end": 16}, + ]) + text = "Heart disease" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 1 + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["start"] == 0 + assert result[0]["end"] == 11 + assert abs(result[0]["score"] - (0.9 + 0.8) / 2) < 1e-6 + + def test_aggregate_bioes_predictions_end_entities(self, huggingface_ner_model): + df = pd.DataFrame([ + {"entity": "O", "score": 0.8, "start": 0, "end": 10}, + {"entity": "E-DISEASE", 
"score": 0.9, "start": 10, "end": 17}, + ]) + text = "has heart disease" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 1 + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["start"] == 10 + assert result[0]["end"] == 17 + assert result[0]["score"] == 0.9 + assert result[0]["text"] == "disease" + + def test_aggregate_bioes_predictions_no_prefix(self): + df = pd.DataFrame([ + {"entity": "DISEASE", "score": 0.9, "start": 0, "end": 5}, + {"entity": "O", "score": 0.8, "start": 6, "end": 10}, + ]) + text = "heart disease patient" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 1 + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["start"] == 0 + assert result[0]["end"] == 5 + assert result[0]["score"] == 0.9 + assert result[0]["text"] == "heart" + + def test_aggregate_bioes_predictions_mixed_entities(self): + df = pd.DataFrame([ + {"entity": "S-DISEASE", "score": 0.9, "start": 0, "end": 5}, + {"entity": "B-MEDICATION", "score": 0.8, "start": 6, "end": 10}, + {"entity": "I-MEDICATION", "score": 0.7, "start": 10, "end": 16}, + {"entity": "O", "score": 0.8, "start": 17, "end": 20}, + {"entity": "E-SYMPTOM", "score": 0.6, "start": 21, "end": 26}, + ]) + text = "heart aspirin and cough" + + result = TagProcessor.aggregate_bioes_predictions(df, text, True) + + assert len(result) == 3 + + assert result[0]["entity_group"] == "DISEASE" + assert result[0]["start"] == 0 + assert result[0]["end"] == 5 + assert result[0]["score"] == 0.9 + assert result[1]["entity_group"] == "MEDICATION" + assert result[1]["start"] == 6 + assert result[1]["end"] == 16 + assert abs(result[1]["score"] - (0.8 + 0.7) / 2) < 1e-6 + assert result[2]["entity_group"] == "SYMPTOM" + assert result[2]["start"] == 21 + assert result[2]["end"] == 26 + + +class TestUpdateModelByTaggingScheme: + + @pytest.fixture + def mock_model(self): + model = MagicMock() + model.config = MagicMock() + model.config.label2id = {"O": 0, "B-PERSON": 1, "I-PERSON": 2} + model.config.id2label = {0: "O", 1: "B-PERSON", 2: "I-PERSON"} + model.classifier = MagicMock() + model.classifier.weight = ones(3, 10) + model.classifier.bias = zeros(3) + model.classifier.out_features = 3 + model.num_labels = 3 + return model + + def test_update_model_by_iob_scheme_new_concepts(self, mock_model): + concepts = ["DISEASE", "MEDICATION"] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.IOB + ) + + assert "B-DISEASE" in updated_model.config.label2id + assert "I-DISEASE" in updated_model.config.label2id + assert "B-MEDICATION" in updated_model.config.label2id + assert "I-MEDICATION" in updated_model.config.label2id + assert updated_model.config.id2label[3] == "B-DISEASE" + assert updated_model.config.id2label[4] == "I-DISEASE" + assert updated_model.config.id2label[5] == "B-MEDICATION" + assert updated_model.config.id2label[6] == "I-MEDICATION" + assert updated_model.classifier.out_features == initial_out_features + 4 + assert updated_model.num_labels == initial_num_labels + 4 + + def test_update_model_by_iob_scheme_existing_concepts(self, mock_model): + concepts = ["PERSON", "DISEASE"] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.IOB + ) 
+ + assert updated_model.num_labels == initial_num_labels + 2 + assert updated_model.classifier.out_features == initial_out_features + 2 + + def test_update_model_by_iobes_scheme_new_concepts(self, mock_model): + concepts = ["DISEASE", "MEDICATION"] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.IOBES + ) + + for concept in concepts: + assert f"S-{concept}" in updated_model.config.label2id + assert f"B-{concept}" in updated_model.config.label2id + assert f"I-{concept}" in updated_model.config.label2id + assert f"E-{concept}" in updated_model.config.label2id + + assert updated_model.classifier.out_features == initial_out_features + 8 + assert updated_model.num_labels == initial_num_labels + 8 + + def test_update_model_by_iobes_scheme_existing_concepts(self, mock_model): + mock_model.config.label2id["B-DISEASE"] = 3 + mock_model.config.label2id["I-DISEASE"] = 4 + mock_model.config.id2label[3] = "B-DISEASE" + mock_model.config.id2label[4] = "I-DISEASE" + concepts = ["DISEASE"] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.IOBES + ) + + assert updated_model.num_labels == initial_num_labels + 2 + assert updated_model.classifier.out_features == initial_out_features + 2 + + def test_update_model_by_flat_scheme_new_concepts(self, mock_model): + concepts = ["DISEASE", "MEDICATION"] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.FLAT + ) + + assert "DISEASE" in updated_model.config.label2id + assert "MEDICATION" in updated_model.config.label2id + assert updated_model.config.id2label[3] == "DISEASE" + assert updated_model.config.id2label[4] == "MEDICATION" + assert updated_model.classifier.out_features == initial_out_features + 2 + assert updated_model.num_labels == initial_num_labels + 2 + + def test_update_model_by_flat_scheme_existing_concepts(self, mock_model): + mock_model.config.label2id["PERSON"] = 3 + mock_model.config.id2label[3] = "PERSON" + mock_model.num_labels = 4 + mock_model.classifier.out_features = 4 + + concepts = ["PERSON", "DISEASE"] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.FLAT + ) + + assert updated_model.num_labels == initial_num_labels + 1 + assert updated_model.classifier.out_features == initial_out_features + 1 + + def test_update_model_by_empty_concepts(self, mock_model): + concepts = [] + initial_num_labels = mock_model.num_labels + initial_out_features = mock_model.classifier.out_features + + updated_model = TagProcessor.update_model_by_tagging_scheme( + mock_model, concepts, TaggingScheme.IOB + ) + + assert updated_model.num_labels == initial_num_labels + assert updated_model.classifier.out_features == initial_out_features + + +class TestGenerateChuncksByTaggingScheme: + + @pytest.fixture + def mock_model(self): + model = MagicMock() + model.config = MagicMock() + model.config.label2id = { + "O": 0, + "B-DISEASE": 1, + "I-DISEASE": 2, + "S-DISEASE": 3, + "E-DISEASE": 4, + "B-MEDICATION": 5, + "I-MEDICATION": 6, + "DISEASE": 7, + 
"MEDICATION": 8, + } + return model + + def test_generate_chuncks_iob_scheme(self, mock_model): + annotations = [ + {"start": 5, "end": 10, "cui": "DISEASE"}, + {"start": 15, "end": 20, "cui": "MEDICATION"}, + ] + tokenized = { + "input_ids": [101, 102, 103, 104, 105, 106, 107, 108, 109, 110], + "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "offset_mapping": [ + (0, 3), (3, 5), (5, 8), (8, 10), (10, 12), (12, 15), (15, 18), (18, 20), (20, 23), (23, 25) + ], + } + + chunks = list(TagProcessor.generate_chuncks_by_tagging_scheme( + annotations=annotations, + tokenized=tokenized, + delfault_label_id=0, + pad_token_id=0, + pad_label_id=-100, + max_length=16, + model=mock_model, + tagging_scheme=TaggingScheme.IOB, + window_size=16, + stride=16, + )) + + assert len(chunks) == 1 + new_tokenized = chunks[0] + assert new_tokenized["labels"][2] == 1 + assert new_tokenized["labels"][3] == 2 + assert new_tokenized["labels"][6] == 5 + assert new_tokenized["labels"][7] == 6 + + def test_generate_chuncks_iobes_scheme(self, mock_model): + annotations = [{"start": 5, "end": 15, "cui": "DISEASE"}] + tokenized = { + "input_ids": [101, 102, 103, 104, 105, 106, 107, 108, 109], + "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "offset_mapping": [(0, 3), (3, 5), (5, 8), (8, 10), (10, 12), (12, 15), (15, 18), (18, 20), (20, 23)], + } + + chunks = list(TagProcessor.generate_chuncks_by_tagging_scheme( + annotations=annotations, + tokenized=tokenized, + delfault_label_id=0, + pad_token_id=0, + pad_label_id=-100, + max_length=16, + model=mock_model, + tagging_scheme=TaggingScheme.IOBES, + window_size=16, + stride=16, + )) + + assert len(chunks) == 1 + new_tokenized = chunks[0] + assert new_tokenized["labels"][2] == 1 + assert new_tokenized["labels"][3] == 2 + assert new_tokenized["labels"][4] == 2 + assert new_tokenized["labels"][5] == 4 + + def test_generate_chuncks_flat_scheme(self, mock_model): + annotations = [ + {"start": 5, "end": 10, "cui": "DISEASE"}, + {"start": 15, "end": 20, "cui": "MEDICATION"}, + ] + tokenized = { + "input_ids": [101, 102, 103, 104, 105, 106, 107, 108, 109, 110], + "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "offset_mapping": [ + (0, 3), (3, 5), (5, 8), (8, 10), (10, 12), (12, 15), (15, 18), (18, 20), (20, 23), (23, 25) + ], + } + + chunks = list(TagProcessor.generate_chuncks_by_tagging_scheme( + annotations=annotations, + tokenized=tokenized, + delfault_label_id=0, + pad_token_id=0, + pad_label_id=-100, + max_length=16, + model=mock_model, + tagging_scheme=TaggingScheme.FLAT, + window_size=16, + stride=16, + )) + + assert len(chunks) == 1 + new_tokenized = chunks[0] + assert new_tokenized["labels"][2] == 7 + assert new_tokenized["labels"][3] == 7 + assert new_tokenized["labels"][6] == 8 + assert new_tokenized["labels"][7] == 8 + + + def test_generate_chuncks_empty_annotations(self, mock_model): + """Test that empty annotations list results in all default labels""" + annotations = [] + tokenized = { + "input_ids": [101, 102, 103, 104, 105], + "attention_mask": [1, 1, 1, 1, 1], + "offset_mapping": [(0, 3), (3, 5), (5, 8), (8, 10), (10, 12)], + } + + chunks = list(TagProcessor.generate_chuncks_by_tagging_scheme( + annotations=annotations, + tokenized=tokenized, + delfault_label_id=0, + pad_token_id=0, + pad_label_id=-100, + max_length=8, + model=mock_model, + tagging_scheme=TaggingScheme.IOB, + window_size=8, + stride=8, + )) + + assert len(chunks) == 1 + new_tokenized = chunks[0] + assert all(label == 0 for label in new_tokenized["labels"][:5]) + assert all(label == -100 for 
+        assert len(new_tokenized["input_ids"]) == 8
+        assert len(new_tokenized["labels"]) == 8
+        assert len(new_tokenized["attention_mask"]) == 8
+
+    def test_generate_chuncks_multiple_chunks(self, mock_model):
+        """A sequence longer than the window is split into overlapping chunks and padded."""
+        annotations = [{"start": 5, "end": 12, "cui": "DISEASE"}]
+        tokenized = {
+            "input_ids": [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118],
+            "attention_mask": [1] * 18,
+            "offset_mapping": [(i*2, i*2+2) for i in range(18)],
+        }
+
+        chunks = list(TagProcessor.generate_chuncks_by_tagging_scheme(
+            annotations=annotations,
+            tokenized=tokenized,
+            delfault_label_id=0,
+            pad_token_id=0,
+            pad_label_id=-100,
+            max_length=8,
+            model=mock_model,
+            tagging_scheme=TaggingScheme.IOB,
+            window_size=8,
+            stride=4,
+        ))
+
+        assert len(chunks) == 5
+        for chunk in chunks:
+            assert len(chunk["input_ids"]) == 8
+            assert len(chunk["labels"]) == 8
+            assert len(chunk["attention_mask"]) == 8
+        assert chunks[0]["input_ids"][:8] == tokenized["input_ids"][0:8]
+        assert chunks[1]["input_ids"][:8] == tokenized["input_ids"][4:12]
+        assert chunks[2]["input_ids"][:8] == tokenized["input_ids"][8:16]
+        assert chunks[3]["input_ids"][6:] == [0, 0]
+        assert chunks[4]["input_ids"][2:] == [0] * 6
diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py
index 2f00e10..3420fe7 100644
--- a/tests/app/test_utils.py
+++ b/tests/app/test_utils.py
@@ -26,6 +26,7 @@
     safetensors_to_pytorch,
     non_default_device_is_available,
     get_hf_pipeline_device_id,
+    get_hf_device_map,
     get_model_data_package_extension,
     unpack_model_data_package,
     create_model_data_package,
@@ -253,6 +254,13 @@ def test_get_hf_pipeline_device_id():
     assert get_hf_pipeline_device_id("mps:1") == 1
 
 
+def test_get_hf_device_map():
+    assert get_hf_device_map("cuda") == {"": "cuda"}
+    assert get_hf_device_map("mps") == {"": "mps"}
+    assert get_hf_device_map("cpu") == {"": "cpu"}
+    assert get_hf_device_map("unknown") == {"": "cpu"}
+
+
 def test_get_model_data_package_extension():
     assert get_model_data_package_extension("model.zip") == ".zip"
     assert get_model_data_package_extension("model.tar.gz") == ".tar.gz"
diff --git a/tests/app/trainers/test_hf_transformer_trainer.py b/tests/app/trainers/test_hf_transformer_trainer.py
index c0da16d..f78b88b 100644
--- a/tests/app/trainers/test_hf_transformer_trainer.py
+++ b/tests/app/trainers/test_hf_transformer_trainer.py
@@ -13,6 +13,7 @@
     _enable_trainer=True,
     _model_pack_path=os.path.join(model_parent_dir, "model.zip"),
 )
+model_service.model.config.max_position_embeddings = 512
 unsupervised_trainer = HuggingFaceNerUnsupervisedTrainer(model_service)
 unsupervised_trainer.model_name = "unsupervised_trainer"
 supervised_trainer = HuggingFaceNerSupervisedTrainer(model_service)
diff --git a/uv.lock b/uv.lock
index 9474867..07a65b5 100644
--- a/uv.lock
+++ b/uv.lock
@@ -708,7 +708,9 @@ dependencies = [
     { name = "python-dotenv" },
     { name = "python-multipart" },
     { name = "sentencepiece" },
+    { name = "seqeval" },
     { name = "slowapi" },
+    { name = "spacy" },
     { name = "toml" },
     { name = "typer" },
     { name = "uvicorn" },
@@ -811,7 +813,9 @@ requires-dist = [
     { name = "python-multipart", specifier = "~=0.0.7" },
     { name = "ruff", marker = "extra == 'dev'", specifier = "==0.6.9" },
     { name = "sentencepiece", specifier = "~=0.2.0" },
+    { name = "seqeval", specifier = ">=1.2.2" },
     { name = "slowapi", specifier = "~=0.1.7" },
+    { name = "spacy", specifier = ">3.8" },
     { name = "sphinx", marker = "extra == 'docs'", specifier = "~=7.1.2" },
"sphinx-autoapi", marker = "extra == 'docs'", specifier = "~=3.5.0" }, { name = "sphinx-autodoc-typehints", marker = "extra == 'docs'", specifier = "~=2.0.1" }, @@ -4458,6 +4462,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/97/d159c32642306ee2b70732077632895438867b3b6df282354bd550cf2a67/sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f", size = 991994, upload-time = "2024-02-19T17:06:45.01Z" }, ] +[[package]] +name = "seqeval" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "scikit-learn" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz", hash = "sha256:f28e97c3ab96d6fcd32b648f6438ff2e09cfba87f05939da9b3970713ec56e6f", size = 43605, upload-time = "2020-10-24T00:24:54.926Z" } + [[package]] name = "setuptools" version = "80.7.1"