Skip to content

Commit 2dacd6f

Browse files
authored
CU-8697x7y9x: Fix issue with transformers 4.47+ affecting DeID (CogStack/MedCAT#517)
* CU-8697x7y9x: Fix issue with transformers 4.47+ affecting DeID * CU-8697x7y9x: Add type-ignore to module unrelated to current change
1 parent 8769064 commit 2dacd6f

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

medcat-v1/medcat/ner/transformers_ner.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,22 @@ def create_eval_pipeline(self):
8989
self.ner_pipe.tokenizer._in_target_context_manager = False
9090
if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
9191
# NOTE: this will fix the DeID model(s) created with transformers before 4.42
92-
# and allow them to run with later transforemrs
92+
# and allow them to run with later transformers
9393
self.ner_pipe.tokenizer.split_special_tokens = False
94+
if not hasattr(self.ner_pipe.tokenizer, 'pad_token') and hasattr(self.ner_pipe.tokenizer, '_pad_token'):
95+
# NOTE: This will fix the DeID model(s) created with transformers before 4.47
96+
# and allow them to run with later transformers versions
97+
# In 4.47 the special tokens started to be used differently, yet our saved model
98+
# is not aware of that. So we need to explicitly fix that.
99+
special_tokens_map = self.ner_pipe.tokenizer.__dict__.get('_special_tokens_map', {})
100+
for name in self.ner_pipe.tokenizer.SPECIAL_TOKENS_ATTRIBUTES:
101+
# previously saved in (e.g.) _pad_token
102+
prev_val = getattr(self.ner_pipe.tokenizer, f"_{name}")
103+
# now saved in the special tokens map by its name
104+
special_tokens_map[name] = prev_val
105+
# the map is saved in __dict__ explicitly, and it is later used in __getattr__ of the base class.
106+
self.ner_pipe.tokenizer.__dict__['_special_tokens_map'] = special_tokens_map
107+
94108
self.ner_pipe.device = self.model.device
95109
self._consecutive_identical_failures = 0
96110
self._last_exception: Optional[Tuple[str, Type[Exception]]] = None

medcat-v1/medcat/utils/relation_extraction/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def load_state(model: BertModel_RelationExtraction, optimizer, scheduler, path="
146146

147147
if optimizer is None:
148148
optimizer = torch.optim.Adam(
149-
[{"params": model.module.parameters(), "lr": config.train.lr}])
149+
[{"params": model.module.parameters(), "lr": config.train.lr}]) # type: ignore
150150

151151
if scheduler is None:
152152
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,

0 commit comments

Comments
 (0)