Skip to content

Commit 7a62ebe

Browse files
authored
CU-86995mmb6 Fix some RelCAT typing issues (CogStack/MedCAT#542)
* CU-86995mmb6: Fix typing issues with modern bert config * CU-86995mmb6: Fix issue with llama tokenizer wrapper * CU-86995mmb6: Fix typing issues with llama config * CU-86995mmb6: Make bert model config types explicit as well * CU-86995mmb6: Remove excessive/unnecessary type hints * CU-86995mmb6: Remove unnecessary type ignore
1 parent 05e6cb5 commit 7a62ebe

File tree

8 files changed

+50
-32
lines changed

8 files changed

+50
-32
lines changed

medcat-v1/medcat/utils/relation_extraction/bert/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22
import os
3+
from typing import cast
4+
35
from medcat.config_rel_cat import ConfigRelCAT
46
from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
57
from transformers import BertConfig
@@ -13,18 +15,19 @@ class BertConfig_RelationExtraction(BaseConfig_RelationExtraction):
1315

1416
name = 'bert-config'
1517
pretrained_model_name_or_path = "bert-base-uncased"
18+
hf_model_config: BertConfig
1619

1720
@classmethod
1821
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "BertConfig_RelationExtraction":
1922
model_config = cls(pretrained_model_name_or_path, **kwargs)
2023

2124
if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path):
22-
model_config.hf_model_config = BertConfig.from_json_file(pretrained_model_name_or_path)
25+
model_config.hf_model_config = cast(BertConfig, BertConfig.from_json_file(pretrained_model_name_or_path))
2326
logger.info("Loaded config from file: " + pretrained_model_name_or_path)
2427
else:
2528
relcat_config.general.model_name = cls.pretrained_model_name_or_path
26-
model_config.hf_model_config = BertConfig.from_pretrained(
27-
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs)
29+
model_config.hf_model_config = cast(BertConfig, BertConfig.from_pretrained(
30+
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs))
2831
logger.info("Loaded config from pretrained: " + relcat_config.general.model_name)
2932

3033
return model_config

medcat-v1/medcat/utils/relation_extraction/bert/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from medcat.config_rel_cat import ConfigRelCAT
1313
from medcat.utils.relation_extraction.ml_utils import create_dense_layers
1414
from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction
15-
from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
1615
from medcat.utils.relation_extraction.bert.config import BertConfig_RelationExtraction
1716

1817

@@ -24,25 +23,25 @@ class BertModel_RelationExtraction(BaseModel_RelationExtraction):
2423

2524
log = logging.getLogger(__name__)
2625

27-
def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction]):
26+
def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: BertConfig_RelationExtraction):
2827
""" Class to hold the BERT model + model_config
2928
3029
Args:
3130
pretrained_model_name_or_path (str): path to load the model from,
3231
this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat'
3332
using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model.
3433
relcat_config (ConfigRelCAT): relcat config.
35-
model_config (Union[BaseConfig_RelationExtraction | BertConfig_RelationExtraction]): HF bert config for model.
34+
model_config (BertConfig_RelationExtraction): HF bert config for model.
3635
"""
3736
super(BertModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path,
3837
relcat_config=relcat_config,
3938
model_config=model_config)
4039

4140
self.relcat_config: ConfigRelCAT = relcat_config
42-
self.model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction] = model_config
41+
self.model_config = model_config
4342
self.pretrained_model_name_or_path: str = pretrained_model_name_or_path
4443

45-
self.hf_model: Union[BertModel, PreTrainedModel] = BertModel(model_config.hf_model_config) # type: ignore
44+
self.hf_model: Union[BertModel, PreTrainedModel] = BertModel(model_config.hf_model_config)
4645

4746
for param in self.hf_model.parameters(): # type: ignore
4847
if self.relcat_config.model.freeze_layers:
@@ -55,8 +54,10 @@ def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelC
5554
# dense layers
5655
self.fc1, self.fc2, self.fc3 = create_dense_layers(self.relcat_config)
5756

57+
# NOTE: ignoring type due to the type of model_config not matching base class exactly (subclass)
5858
@classmethod
59-
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction], **kwargs) -> "BertModel_RelationExtraction":
59+
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, # type: ignore
60+
model_config: BertConfig_RelationExtraction, **kwargs) -> "BertModel_RelationExtraction":
6061

6162
model = BertModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path,
6263
relcat_config=relcat_config,

medcat-v1/medcat/utils/relation_extraction/llama/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22
import os
3+
from typing import cast
4+
35
from medcat.config_rel_cat import ConfigRelCAT
46
from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
57
from transformers import LlamaConfig
@@ -13,18 +15,19 @@ class LlamaConfig_RelationExtraction(BaseConfig_RelationExtraction):
1315

1416
name = 'llama-config'
1517
pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B"
18+
hf_model_config: LlamaConfig
1619

1720
@classmethod
1821
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "LlamaConfig_RelationExtraction":
1922
model_config = cls(pretrained_model_name_or_path, **kwargs)
2023

2124
if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path):
22-
model_config.hf_model_config = LlamaConfig.from_json_file(pretrained_model_name_or_path)
25+
model_config.hf_model_config = cast(LlamaConfig, LlamaConfig.from_json_file(pretrained_model_name_or_path))
2326
logger.info("Loaded config from file: " + pretrained_model_name_or_path)
2427
else:
2528
relcat_config.general.model_name = cls.pretrained_model_name_or_path
26-
model_config.hf_model_config = LlamaConfig.from_pretrained(
27-
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs)
29+
model_config.hf_model_config = cast(LlamaConfig, LlamaConfig.from_pretrained(
30+
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs))
2831
logger.info("Loaded config from pretrained: " + relcat_config.general.model_name)
2932

3033
return model_config

medcat-v1/medcat/utils/relation_extraction/llama/model.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import logging
2-
from typing import Any, Optional, Tuple, Union
2+
from typing import Any, Optional, Tuple
33
import torch
44
from torch import nn
55
import os
66
from transformers.models.llama import LlamaModel
77

88
from medcat.config_rel_cat import ConfigRelCAT
9-
from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
109
from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction
1110
from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction
1211
from medcat.utils.relation_extraction.ml_utils import create_dense_layers, get_annotation_schema_tag
@@ -20,25 +19,25 @@ class LlamaModel_RelationExtraction(BaseModel_RelationExtraction):
2019

2120
log = logging.getLogger(__name__)
2221

23-
def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction]):
22+
def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: LlamaConfig_RelationExtraction):
2423
""" Class to hold the Llama model + model_config
2524
2625
Args:
2726
pretrained_model_name_or_path (str): path to load the model from,
2827
this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat'
2928
using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model.
3029
relcat_config (ConfigRelCAT): relcat config.
31-
model_config (Union[BaseConfig_RelationExtraction | LlamaConfig_RelationExtraction]): HF bert config for model.
30+
model_config (LlamaConfig_RelationExtraction): HF bert config for model.
3231
"""
3332

3433
super(LlamaModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path,
3534
relcat_config=relcat_config,
3635
model_config=model_config)
3736

3837
self.relcat_config: ConfigRelCAT = relcat_config
39-
self.model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction] = model_config
38+
self.model_config = model_config
4039

41-
self.hf_model: LlamaModel = LlamaModel(config=model_config) # type: ignore
40+
self.hf_model: LlamaModel = LlamaModel(config=model_config.hf_model_config)
4241

4342
if pretrained_model_name_or_path != "":
4443
self.hf_model = LlamaModel.from_pretrained(pretrained_model_name_or_path, config=model_config, ignore_mismatched_sizes=True)
@@ -162,8 +161,10 @@ def forward(self,
162161

163162
return model_output, classification_logits.to(self.relcat_config.general.device)
164163

164+
# NOTE: ignoring type due to the type of model_config not matching base class exactly (subclass)
165165
@classmethod
166-
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction], **kwargs) -> "LlamaModel_RelationExtraction":
166+
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, # type: ignore
167+
model_config: LlamaConfig_RelationExtraction, **kwargs) -> "LlamaModel_RelationExtraction":
167168

168169
model = LlamaModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path,
169170
relcat_config=relcat_config,

medcat-v1/medcat/utils/relation_extraction/llama/tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ def load(cls, tokenizer_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "To
3030
else:
3131
relcat_config.general.model_name = cls.pretrained_model_name_or_path
3232
tokenizer.hf_tokenizers = LlamaTokenizerFast.from_pretrained(
33-
path=relcat_config.general.model_name)
33+
relcat_config.general.model_name)
3434
return tokenizer

medcat-v1/medcat/utils/relation_extraction/models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import torch
3-
from typing import Any, Optional, Tuple, Union
3+
from typing import Any, Optional, Tuple, Union, cast
44
from torch import nn
55
from transformers import PretrainedConfig, PreTrainedModel
66

@@ -240,15 +240,21 @@ def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, m
240240
if "modern-bert" in relcat_config.general.tokenizer_name or \
241241
"modern-bert" in relcat_config.general.model_name:
242242
from medcat.utils.relation_extraction.modernbert.model import ModernBertModel_RelationExtraction
243-
model = ModernBertModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config, model_config=model_config)
243+
from medcat.utils.relation_extraction.modernbert.config import ModernBertConfig_RelationExtraction
244+
model = ModernBertModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config,
245+
model_config=cast(ModernBertConfig_RelationExtraction, model_config))
244246
elif "bert" in relcat_config.general.tokenizer_name or \
245247
"bert" in relcat_config.general.model_name:
246248
from medcat.utils.relation_extraction.bert.model import BertModel_RelationExtraction
247-
model = BertModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config, model_config=model_config)
249+
from medcat.utils.relation_extraction.bert.config import BertConfig_RelationExtraction
250+
model = BertModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config,
251+
model_config=cast(BertConfig_RelationExtraction, model_config))
248252
elif "llama" in relcat_config.general.tokenizer_name or \
249253
"llama" in relcat_config.general.model_name:
250254
from medcat.utils.relation_extraction.llama.model import LlamaModel_RelationExtraction
251-
model = LlamaModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config, model_config=model_config)
255+
from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction
256+
model = LlamaModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config,
257+
model_config=cast(LlamaConfig_RelationExtraction, model_config))
252258
else:
253259
if pretrained_model_name_or_path:
254260
model.hf_model = PreTrainedModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config)

medcat-v1/medcat/utils/relation_extraction/modernbert/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22
import os
3+
from typing import cast
4+
35
from medcat.config_rel_cat import ConfigRelCAT
46
from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
57
from transformers import ModernBertConfig
@@ -13,18 +15,19 @@ class ModernBertConfig_RelationExtraction(BaseConfig_RelationExtraction):
1315

1416
name = 'modern-bert-config'
1517
pretrained_model_name_or_path = "answerdotai/ModernBERT-base"
18+
hf_model_config: ModernBertConfig
1619

1720
@classmethod
1821
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "ModernBertConfig_RelationExtraction":
1922
model_config = cls(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
2023

2124
if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path):
22-
model_config.hf_model_config = ModernBertConfig.from_json_file(pretrained_model_name_or_path)
25+
model_config.hf_model_config = cast(ModernBertConfig, ModernBertConfig.from_json_file(pretrained_model_name_or_path))
2326
logger.info("Loaded config from file: " + pretrained_model_name_or_path)
2427
else:
2528
relcat_config.general.model_name = cls.pretrained_model_name_or_path
26-
model_config.hf_model_config = ModernBertConfig.from_pretrained(
27-
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs)
29+
model_config.hf_model_config = cast(ModernBertConfig, ModernBertConfig.from_pretrained(
30+
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs))
2831
logger.info("Loaded config from pretrained: " + relcat_config.general.model_name)
2932

3033
return model_config

medcat-v1/medcat/utils/relation_extraction/modernbert/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from transformers import PreTrainedModel
1010
from medcat.utils.relation_extraction.ml_utils import create_dense_layers
1111
from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction
12-
from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction
1312
from medcat.utils.relation_extraction.modernbert.config import ModernBertConfig_RelationExtraction
1413

1514

@@ -22,22 +21,22 @@ class ModernBertModel_RelationExtraction(BaseModel_RelationExtraction):
2221

2322
log = logging.getLogger(__name__)
2423

25-
def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction]):
24+
def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: ModernBertConfig_RelationExtraction):
2625
""" Class to hold the ModernBERT model + model_config
2726
2827
Args:
2928
pretrained_model_name_or_path (str): path to load the model from,
3029
this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat'
3130
using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model.
3231
relcat_config (ConfigRelCAT): relcat config.
33-
model_config (Union[BaseConfig_RelationExtraction | ModernBertConfig_RelationExtraction]): HF bert config for model.
32+
model_config (ModernBertConfig_RelationExtraction): HF bert config for model.
3433
"""
3534
super(ModernBertModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path,
3635
relcat_config=relcat_config,
3736
model_config=model_config)
3837

3938
self.relcat_config: ConfigRelCAT = relcat_config
40-
self.model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction] = model_config
39+
self.model_config = model_config
4140
self.pretrained_model_name_or_path: str = pretrained_model_name_or_path
4241

4342
self.hf_model: Union[ModernBertModel, PreTrainedModel] = ModernBertModel(config=model_config.hf_model_config)
@@ -53,8 +52,10 @@ def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelC
5352
# dense layers
5453
self.fc1, self.fc2, self.fc3 = create_dense_layers(self.relcat_config)
5554

55+
# NOTE: ignoring type due to the type of model_config not matching base class exactly (subclass)
5656
@classmethod
57-
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction], **kwargs) -> "ModernBertModel_RelationExtraction":
57+
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, # type: ignore
58+
model_config: ModernBertConfig_RelationExtraction, **kwargs) -> "ModernBertModel_RelationExtraction":
5859

5960
model = ModernBertModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path,
6061
relcat_config=relcat_config,

0 commit comments

Comments
 (0)