|
2 | 2 | import json |
3 | 3 | import logging |
4 | 4 | import datasets |
| 5 | +import torch |
5 | 6 | from spacy.tokens import Doc |
6 | 7 | from datetime import datetime |
7 | 8 | from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type |
@@ -330,6 +331,63 @@ def save(self, save_dir_path: str) -> None: |
330 | 331 | # This is everything we need to save from the class, we do not |
331 | 332 | # save the class itself.
332 | 333 |
|
| 334 | + def expand_model_with_concepts(self, cui2preferred_name: Dict[str, str], use_avg_init: bool = True) -> None: |
| 335 | + """Expand the model with new concepts and their preferred names, which requires subsequent retraining on the model. |
| 336 | +
|
| 337 | + Args: |
| 338 | + cui2preferred_name(Dict[str, str]): |
| 339 | + Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name. |
| 340 | + use_avg_init(bool): |
| 341 | + Whether to use the average of existing weights or biases as the initial value for the new concept. Defaults to True. |
| 342 | + """ |
| 343 | + |
| 344 | + avg_weight = torch.mean(self.model.classifier.weight, dim=0, keepdim=True) |
| 345 | + avg_bias = torch.mean(self.model.classifier.bias, dim=0, keepdim=True) |
| 346 | + |
| 347 | + new_cuis = set() |
| 348 | + for label, preferred_name in cui2preferred_name.items(): |
| 349 | + if label in self.model.config.label2id.keys(): |
| 350 | + logger.warning("Concept ID '%s' already exists in the model, skipping...", label) |
| 351 | + continue |
| 352 | + |
| 353 | + sname = preferred_name.lower().replace(" ", "~") |
| 354 | + new_names = { |
| 355 | + sname: { |
| 356 | + "tokens": [], |
| 357 | + "snames": [sname], |
| 358 | + "raw_name": preferred_name, |
| 359 | + "is_upper": True |
| 360 | + } |
| 361 | + } |
| 362 | + self.cdb.add_names(cui=label, names=new_names, name_status="P", full_build=True) |
| 363 | + |
| 364 | + new_label_id = sorted(self.model.config.label2id.values())[-1] + 1 |
| 365 | + self.model.config.label2id[label] = new_label_id |
| 366 | + self.model.config.id2label[new_label_id] = label |
| 367 | + self.tokenizer.label_map[label] = new_label_id |
| 368 | + self.tokenizer.cui2name = {k: self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()} |
| 369 | + |
| 370 | + if use_avg_init: |
| 371 | + self.model.classifier.weight = torch.nn.Parameter( |
| 372 | + torch.cat((self.model.classifier.weight, avg_weight), 0) |
| 373 | + ) |
| 374 | + self.model.classifier.bias = torch.nn.Parameter( |
| 375 | + torch.cat((self.model.classifier.bias, avg_bias), 0) |
| 376 | + ) |
| 377 | + else: |
| 378 | + self.model.classifier.weight = torch.nn.Parameter( |
| 379 | + torch.cat((self.model.classifier.weight, torch.randn(1, self.model.config.hidden_size)), 0) |
| 380 | + ) |
| 381 | + self.model.classifier.bias = torch.nn.Parameter( |
| 382 | + torch.cat((self.model.classifier.bias, torch.randn(1)), 0) |
| 383 | + ) |
| 384 | + self.model.num_labels += 1 |
| 385 | + self.model.classifier.out_features += 1 |
| 386 | + |
| 387 | + new_cuis.add(label) |
| 388 | + |
| 389 | + logger.info("Model expanded with the new concept(s): %s and shall be retrained before use.", str(new_cuis)) |
| 390 | + |
333 | 391 | @classmethod |
334 | 392 | def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "TransformersNER": |
335 | 393 | """Load a meta_cat object. |
|
0 commit comments