Skip to content

Commit ee3c640

Browse files
authored
Support expansion of transformers ner models to include new concepts (CogStack/MedCAT#519)
* CU-8697v6qr2 support expansion of transformers ner models to include new concepts * CU-8697v6qr2 add logging suggested by the review
1 parent f50a23c commit ee3c640

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

medcat-v1/medcat/ner/transformers_ner.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import datasets
5+
import torch
56
from spacy.tokens import Doc
67
from datetime import datetime
78
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type
@@ -330,6 +331,63 @@ def save(self, save_dir_path: str) -> None:
330331
# This is everything we need to save from the class, we do not
331332
#save the class itself.
332333

334+
def expand_model_with_concepts(self, cui2preferred_name: Dict[str, str], use_avg_init: bool = True) -> None:
335+
"""Expand the model with new concepts and their preferred names, which requires subsequent retraining on the model.
336+
337+
Args:
338+
cui2preferred_name(Dict[str, str]):
339+
Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name.
340+
use_avg_init(bool):
341+
Whether to use the average of existing weights or biases as the initial value for the new concept. Defaults to True.
342+
"""
343+
344+
avg_weight = torch.mean(self.model.classifier.weight, dim=0, keepdim=True)
345+
avg_bias = torch.mean(self.model.classifier.bias, dim=0, keepdim=True)
346+
347+
new_cuis = set()
348+
for label, preferred_name in cui2preferred_name.items():
349+
if label in self.model.config.label2id.keys():
350+
logger.warning("Concept ID '%s' already exists in the model, skipping...", label)
351+
continue
352+
353+
sname = preferred_name.lower().replace(" ", "~")
354+
new_names = {
355+
sname: {
356+
"tokens": [],
357+
"snames": [sname],
358+
"raw_name": preferred_name,
359+
"is_upper": True
360+
}
361+
}
362+
self.cdb.add_names(cui=label, names=new_names, name_status="P", full_build=True)
363+
364+
new_label_id = sorted(self.model.config.label2id.values())[-1] + 1
365+
self.model.config.label2id[label] = new_label_id
366+
self.model.config.id2label[new_label_id] = label
367+
self.tokenizer.label_map[label] = new_label_id
368+
self.tokenizer.cui2name = {k: self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()}
369+
370+
if use_avg_init:
371+
self.model.classifier.weight = torch.nn.Parameter(
372+
torch.cat((self.model.classifier.weight, avg_weight), 0)
373+
)
374+
self.model.classifier.bias = torch.nn.Parameter(
375+
torch.cat((self.model.classifier.bias, avg_bias), 0)
376+
)
377+
else:
378+
self.model.classifier.weight = torch.nn.Parameter(
379+
torch.cat((self.model.classifier.weight, torch.randn(1, self.model.config.hidden_size)), 0)
380+
)
381+
self.model.classifier.bias = torch.nn.Parameter(
382+
torch.cat((self.model.classifier.bias, torch.randn(1)), 0)
383+
)
384+
self.model.num_labels += 1
385+
self.model.classifier.out_features += 1
386+
387+
new_cuis.add(label)
388+
389+
logger.info("Model expanded with the new concept(s): %s and shall be retrained before use.", str(new_cuis))
390+
333391
@classmethod
334392
def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "TransformersNER":
335393
"""Load a meta_cat object.

medcat-v1/medcat/utils/ner/model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,21 @@ def get_entities(self, text: str, *args, **kwargs) -> dict:
7676
"""
7777
return self.cat.get_entities(text, *args, **kwargs)
7878

79+
def add_new_concepts(self,
80+
cui2preferred_name: Dict[str, str],
81+
train_nr: int = 0,
82+
with_random_init: bool = False) -> None:
83+
"""Add new concepts to the model and the concept database.
84+
85+
Invoking this requires subsequent retraining on the model.
86+
87+
Args:
88+
cui2preferred_name(Dict[str, str]): Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name.
89+
train_nr (int): The number of the NER object in cat._addl_train to which new concepts will be added. Defaults to 0.
90+
with_random_init (bool): Whether to use the random init strategy for the new concepts. Defaults to False.
91+
"""
92+
self.cat._addl_ner[train_nr].expand_model_with_concepts(cui2preferred_name, use_avg_init=not with_random_init)
93+
7994
@property
8095
def config(self) -> Config:
8196
return self.cat.config

medcat-v1/tests/ner/test_transformers_ner.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,20 @@ def on_epoch_end(self, *args, **kwargs) -> None:
4848
assert dataset["train"].num_rows == 48
4949
assert dataset["test"].num_rows == 12
5050
self.assertEqual(tracker.call.call_count, 2)
51+
52+
def test_expand_model_with_concepts(self):
53+
original_num_labels = self.undertest.model.num_labels
54+
original_out_features = self.undertest.model.classifier.out_features
55+
original_label_map_size = len(self.undertest.tokenizer.label_map)
56+
cui2preferred_name = {
57+
"concept_1" : "Preferred Name 1",
58+
"concept_2" : "Preferred Name 2",
59+
}
60+
61+
self.undertest.expand_model_with_concepts(cui2preferred_name)
62+
63+
assert self.undertest.model.num_labels == original_num_labels + len(cui2preferred_name)
64+
assert self.undertest.model.classifier.out_features == original_out_features + len(cui2preferred_name)
65+
assert len(self.undertest.tokenizer.label_map) == original_label_map_size + len(cui2preferred_name)
66+
assert self.undertest.tokenizer.cui2name.get("concept_1") == "Preferred Name 1"
67+
assert self.undertest.tokenizer.cui2name.get("concept_2") == "Preferred Name 2"

medcat-v1/tests/utils/ner/test_deid.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def test_training(self):
9090
self.assertIsNotNone(examples)
9191
self.assertIsNotNone(dataset)
9292

93+
def test_add_new_concepts(self):
94+
self.deid_model.add_new_concepts({'CONCEPT': "Concept"}, with_random_init=True)
95+
self.assertTrue("CONCEPT" in self.deid_model.cat.cdb.cui2names)
96+
self.assertEqual(self.deid_model.cat.cdb.cui2names["CONCEPT"], {"concept"})
97+
self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].model.config.label2id)
98+
self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].tokenizer.label_map)
99+
self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].tokenizer.cui2name)
93100

94101
input_text = '''
95102
James Joyce

0 commit comments

Comments
 (0)