@@ -1,12 +1,11 @@
 import logging
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Tuple
 import torch
 from torch import nn
 import os
 from transformers.models.llama import LlamaModel
 
 from medcat.config_rel_cat import ConfigRelCAT
-from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
 from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction
 from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction
 from medcat.utils.relation_extraction.ml_utils import create_dense_layers, get_annotation_schema_tag
@@ -20,25 +19,25 @@ class LlamaModel_RelationExtraction(BaseModel_RelationExtraction):
 
     log = logging.getLogger(__name__)
 
-    def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction]):
+    def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: LlamaConfig_RelationExtraction):
         """Class to hold the Llama model + model_config
 
         Args:
             pretrained_model_name_or_path (str): path to load the model from,
                 this can be a HF model, e.g. "bert-base-uncased"; if left empty, it is normally assumed that a model is loaded from 'model.dat'
                 using the RelCAT.load() method. So if you are initializing/training a model from scratch, be sure to base it on some model.
             relcat_config (ConfigRelCAT): relcat config.
-            model_config (Union[BaseConfig_RelationExtraction | LlamaConfig_RelationExtraction]): HF bert config for model.
+            model_config (LlamaConfig_RelationExtraction): config for the Llama model.
         """
 
         super(LlamaModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path,
                                                             relcat_config=relcat_config,
                                                             model_config=model_config)
 
         self.relcat_config: ConfigRelCAT = relcat_config
-        self.model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction] = model_config
+        self.model_config = model_config
 
-        self.hf_model: LlamaModel = LlamaModel(config=model_config)  # type: ignore
+        self.hf_model: LlamaModel = LlamaModel(config=model_config.hf_model_config)
 
         if pretrained_model_name_or_path != "":
             self.hf_model = LlamaModel.from_pretrained(pretrained_model_name_or_path, config=model_config, ignore_mismatched_sizes=True)
@@ -162,8 +161,10 @@ def forward(self,
 
         return model_output, classification_logits.to(self.relcat_config.general.device)
 
+    # NOTE: ignoring type because the type of model_config does not exactly match the base class (it is a subclass)
     @classmethod
-    def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction], **kwargs) -> "LlamaModel_RelationExtraction":
+    def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT,  # type: ignore
+             model_config: LlamaConfig_RelationExtraction, **kwargs) -> "LlamaModel_RelationExtraction":
 
         model = LlamaModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path,
                                               relcat_config=relcat_config,
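
A minimal usage sketch of the narrowed signature after this change (not part of the commit; the model module path and the bare LlamaConfig_RelationExtraction constructor below are assumptions for illustration, as is the HF model id):

    from medcat.config_rel_cat import ConfigRelCAT
    from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction
    # hypothetical import path for the class edited above
    from medcat.utils.relation_extraction.llama.model import LlamaModel_RelationExtraction

    relcat_config = ConfigRelCAT()
    # assumed construction; the diff only shows that the wrapper exposes .hf_model_config
    model_config = LlamaConfig_RelationExtraction()
    # model_config is now typed as LlamaConfig_RelationExtraction rather than the Union,
    # and the underlying HF config is passed via model_config.hf_model_config internally
    model = LlamaModel_RelationExtraction(
        pretrained_model_name_or_path="meta-llama/Llama-2-7b-hf",
        relcat_config=relcat_config,
        model_config=model_config,
    )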