Skip to content

Commit f50a23c

Browse files
Updates for MetaCAT (CogStack/MedCAT#515)
* Pushing update for MetaCAT - Addressing the multiple zero-division-error warnings per epoch while training - Accommodating the variations in category name and class name across NHS sites * Adding comments * Pushing requested changes * Pushing type fix * Pushing updates to metacat config
1 parent 2dacd6f commit f50a23c

File tree

4 files changed

+78
-23
lines changed

4 files changed

+78
-23
lines changed

medcat-v1/medcat/config_meta_cat.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Any
1+
from typing import Dict, Any, List
22
from medcat.config import MixingConfig, BaseModel, Optional
33

44

@@ -27,8 +27,22 @@ class General(MixingConfig, BaseModel):
2727
"""What category is this meta_cat model predicting/training.
2828
2929
NB! For these changes to take effect, the pipe would need to be recreated."""
30+
alternative_category_names: List = []
31+
"""List that stores the variations of possible category names
32+
Example: For Experiencer, the alternate name is Subject
33+
alternative_category_names: ['Experiencer','Subject']
34+
35+
If the name specified in self.general.category_name is not present in the data, the first name from this list that is present in the data is used instead, so no error is raised
36+
"""
3037
category_value2id: Dict = {}
3138
"""Map from category values to ID, if empty it will be autocalculated during training"""
39+
alternative_class_names: List[List] = [[]]
40+
"""List of lists that stores the variations of possible class names for each class mentioned in self.general.category_value2id
41+
42+
Example: For Presence task, the class names vary across NHS sites.
43+
To accommodate this, alternative_class_names is populated as: [["Hypothetical (N/A)","Hypothetical"],["Not present (False)","False"],["Present (True)","True"]]
44+
Each sub list contains the possible variations of the given class.
45+
"""
3246
vocab_size: Optional[int] = None
3347
"""Will be set automatically if the tokenizer is provided during meta_cat init"""
3448
lowercase: bool = True

medcat-v1/medcat/meta_cat.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,17 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
244244

245245
# Check if the name is present
246246
category_name = g_config['category_name']
247+
category_name_options = g_config['alternative_category_names']
247248
if category_name not in data:
248-
raise Exception(
249-
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
250-
category_name, " | ".join(list(data.keys()))))
249+
category_matching = [cat for cat in category_name_options if cat in data.keys()]
250+
if len(category_matching) > 0:
251+
logger.info("The category name provided in the config - '%s' is not present in the data. However, the corresponding name - '%s' from the category_name_mapping has been found. Updating the category name...",category_name,*category_matching)
252+
g_config['category_name'] = category_matching[0]
253+
category_name = g_config['category_name']
254+
else:
255+
raise Exception(
256+
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}. Additionally, ensure the populate the 'alternative_category_names' attribute to accommodate for variations.".format(
257+
category_name, " | ".join(list(data.keys()))))
251258

252259
data = data[category_name]
253260
if data_oversampled:
@@ -258,27 +265,21 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
258265
if not category_value2id:
259266
# Encode the category values
260267
full_data, data_undersampled, category_value2id = encode_category_values(data,
261-
category_undersample=self.config.model.category_undersample)
262-
g_config['category_value2id'] = category_value2id
268+
category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names'])
263269
else:
264270
# We already have everything, just get the data
265271
full_data, data_undersampled, category_value2id = encode_category_values(data,
266272
existing_category_value2id=category_value2id,
267-
category_undersample=self.config.model.category_undersample)
268-
g_config['category_value2id'] = category_value2id
269-
# Make sure the config number of classes is the same as the one found in the data
270-
if len(category_value2id) != self.config.model['nclasses']:
271-
logger.warning(
272-
"The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id))
273-
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
274-
self.config.model['nclasses'] = len(category_value2id)
273+
category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names'])
274+
g_config['category_value2id'] = category_value2id
275+
self.config.model['nclasses'] = len(category_value2id)
275276

276277
if self.config.model.phase_number == 2 and save_dir_path is not None:
277278
model_save_path = os.path.join(save_dir_path, 'model.dat')
278279
device = torch.device(g_config['device'])
279280
try:
280281
self.model.load_state_dict(torch.load(model_save_path, map_location=device))
281-
logger.info("Model state loaded from dict for 2 phase learning")
282+
logger.info("Training model for Phase 2, with model dict loaded from disk")
282283

283284
except FileNotFoundError:
284285
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")
@@ -295,6 +296,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
295296
if not t_config['auto_save_model']:
296297
logger.info("For phase 1, model state has to be saved. Saving model...")
297298
t_config['auto_save_model'] = True
299+
logger.info("Training model for Phase 1 now...")
298300

299301
report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)
300302

medcat-v1/medcat/utils/meta_cat/data_utils.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Optional, Tuple, Iterable, List
22
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
3+
import copy
34
import logging
45

56
logger = logging.getLogger(__name__)
@@ -153,7 +154,7 @@ def prepare_for_oversampled_data(data: List,
153154

154155

155156
def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None,
156-
category_undersample=None) -> Tuple:
157+
category_undersample=None, alternative_class_names: List[List] = []) -> Tuple:
157158
"""Converts the category values in the data outputted by `prepare_from_json`
158159
into integer values.
159160
@@ -164,6 +165,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
164165
Map from category_value to id (old/existing).
165166
category_undersample:
166167
Name of class that should be used to undersample the data (for 2 phase learning)
168+
alternative_class_names:
169+
List of lists that stores the variations of possible class names for the given category (task); each sub-list holds the variations of one class
167170
168171
Returns:
169172
dict:
@@ -172,6 +175,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
172175
New undersampled data (for 2 phase learning) with integers inplace of strings for category values
173176
dict:
174177
Map from category value to ID for all categories in the data.
178+
179+
Raises:
180+
Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data
175181
"""
176182
data = list(data)
177183
if existing_category_value2id is not None:
@@ -180,9 +186,42 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
180186
category_value2id = {}
181187

182188
category_values = set([x[2] for x in data])
183-
for c in category_values:
184-
if c not in category_value2id:
185-
category_value2id[c] = len(category_value2id)
189+
190+
# If category_value2id is pre-defined, make sure its labels match the labels found in the data
191+
if len(category_value2id) != 0:
192+
if set(category_value2id.keys()) != category_values:
193+
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
194+
if len(alternative_class_names) != 0:
195+
updated_category_value2id = {}
196+
for _class in category_value2id.keys():
197+
if _class in category_values:
198+
updated_category_value2id[_class] = category_value2id[_class]
199+
else:
200+
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
201+
if len(found_in) != 0:
202+
class_name_matched = [label for label in found_in[0] if label in category_values]
203+
if len(class_name_matched) != 0:
204+
updated_category_value2id[class_name_matched] = category_value2id[_class]
205+
logger.info("Class name '%s' does not exist in the data; however a variation of it '%s' is present; updating it...",_class,class_name_matched)
206+
else:
207+
raise Exception(
208+
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
209+
else:
210+
raise Exception(f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
211+
category_value2id = copy.deepcopy(updated_category_value2id)
212+
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
213+
214+
# Else throw an exception since the labels don't match
215+
else:
216+
raise Exception(
217+
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
218+
219+
# Else create the mapping from the labels found in the data
220+
else:
221+
for c in category_values:
222+
if c not in category_value2id:
223+
category_value2id[c] = len(category_value2id)
224+
logger.info("Categoryvalue2id mapping created with labels found in the data - %s", category_value2id)
186225

187226
# Map values to numbers
188227
for i in range(len(data)):
@@ -194,7 +233,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
194233
if data[i][2] in category_value2id.values():
195234
label_data_[data[i][2]] = label_data_[data[i][2]] + 1
196235

197-
logger.info("Original label_data: %s",label_data_)
236+
logger.info("Original number of samples per label: %s",label_data_)
198237
# Undersampling data
199238
if category_undersample is None or category_undersample == '':
200239
min_label = min(label_data_.values())
@@ -217,7 +256,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
217256
for i in range(len(data_undersampled)):
218257
if data_undersampled[i][2] in category_value2id.values():
219258
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
220-
logger.info("Updated label_data: %s",label_data)
259+
logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data)
221260

222261
return data, data_undersampled, category_value2id
223262

medcat-v1/medcat/utils/meta_cat/ml_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,12 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
329329
print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test')
330330

331331
_report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
332-
output_dict=True)
332+
output_dict=True,zero_division=0)
333333
if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \
334334
winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]:
335335

336336
report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
337-
output_dict=True)
337+
output_dict=True,zero_division=0)
338338
cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true')
339339
report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1),
340340
output_dict=True)

0 commit comments

Comments
 (0)