@@ -1,16 +1,28 @@
+from typing import Dict, List
 from sklearn.metrics import classification_report
 import numpy as np
 import pandas as pd
 from collections import defaultdict
 from scipy.special import softmax
 import logging

+from medcat.cdb import CDB
+

 logger = logging.getLogger(__name__)


 def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, merged_negative={0, 1, -100}, padding_label=-100, csize=15, subword_label=1,
             verbose=False):
+    """
+    Calculate metrics for a model's predictions, based on the tokenized output of a MedCATTrainer project.
+
+    Args:
+        p: The model's predictions.
+        return_df: Whether to return the per-label metrics DataFrame (and collected examples) instead of an aggregated metrics dict.
+        plus_recall: An amount added to bias the model's predictions towards higher recall.
+        tokenizer: The tokenizer used to tokenize the texts.
+    """
     """TODO: This could be done better, for sure. But it works.""" # noqa
     predictions = np.array(p.predictions)
     predictions = softmax(predictions, axis=2)
@@ -117,3 +129,88 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer
                 'precison_merged': np.average([x for x in df.p_merged.values if pd.notna(x)])}
     else:
         return df, examples
+
+
+def _anno_within_pred_list(label: Dict, preds: List[Dict]) -> bool:
+    """
+    Check whether a label's span is fully enclosed by any prediction in a list of predictions.
+
+    Args:
+        label (Dict): An annotation, likely from a MedCATTrainer project.
+        preds (List[Dict]): A list of predictions, likely from a cat.__call__.
+
+    Returns:
+        bool: True if the label falls within one of the predictions, False otherwise.
+    """
+    return any(label['start'] >= p['start'] and label['end'] <= p['end'] for p in preds)
+
+
+def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[List[Dict]], texts: List[str], deid_cdb: CDB):
+    """
+    Evaluate predictions against sets of labels collected and exported from a MedCATTrainer project.
+    A prediction counts as correct if it fully encloses the label span.
+
+    Args:
+        true_annotations (List[List[Dict]]): Ground truth annotations, one list per text.
+        all_preds (List[List[Dict]]): Model predictions, one list per text.
+        texts (List[str]): The original list of texts.
+        deid_cdb (CDB): Concept database used to resolve CUI preferred names.
+
+    Returns:
+        Tuple[pd.DataFrame, Dict]: A tuple containing a DataFrame of evaluation metrics and a dictionary of missed annotations per CUI.
+    """
+    per_cui_recall = {}
+    per_cui_prec = {}
+    per_cui_recall_merged = {}
+    per_cui_anno_counts = {}
+    per_cui_annos_missed = defaultdict(list)
+    uniq_labels = {p['cui'] for ap in true_annotations for p in ap}
+
+    for cui in uniq_labels:
+        # annos in test set
+        anno_count = sum([len([p for p in cui_annos if p['cui'] == cui]) for cui_annos in true_annotations])
+        pred_counts = sum([len([p for p in d if p['cui'] == cui]) for d in all_preds])
+
+        # print(anno_count)
+        # print(pred_counts)
+
+        # print(f'pred_count: {pred_counts}, anno_count:{anno_count}')
+        per_cui_anno_counts[cui] = anno_count
+
+        doc_annos_left, preds_left, doc_annos_left_any_cui = [], [], []
+
+        for doc_preds, doc_labels, text in zip(all_preds, true_annotations, texts):
+            # num of annos that are not found - recall
+            cui_labels = [lab for lab in doc_labels if lab['cui'] == cui]
+            cui_doc_preds = [p for p in doc_preds if p['cui'] == cui]
+
+            labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, cui_doc_preds)]
+            doc_annos_left.append(len(labels_not_found))
+
+            # num of annos that are not found across any cui prediction - recall_merged
+            any_labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, doc_preds)]
+            doc_annos_left_any_cui.append(len(any_labels_not_found))
+
+            per_cui_annos_missed[cui].append(any_labels_not_found)
+
+            # num of preds that are incorrect - precision
+            preds_left.append(len([label for label in cui_doc_preds if not _anno_within_pred_list(label, cui_labels)]))
+
+        if anno_count != 0 and pred_counts != 0:
+            per_cui_recall[cui] = (anno_count - sum(doc_annos_left)) / anno_count
+            per_cui_recall_merged[cui] = (anno_count - sum(doc_annos_left_any_cui)) / anno_count
+            per_cui_prec[cui] = (pred_counts - sum(preds_left)) / pred_counts
+        else:
+            per_cui_recall[cui] = 0
+            per_cui_recall_merged[cui] = 0
+            per_cui_prec[cui] = 0
+
+    res_df = pd.DataFrame({
+        'cui': per_cui_recall_merged.keys(),
+        'recall_merged': per_cui_recall_merged.values(),
+        'recall': per_cui_recall.values(),
+        'precision': per_cui_prec.values(),
+        'label_count': per_cui_anno_counts.values()}, index=[deid_cdb.cui2preferred_name[k] for k in per_cui_recall_merged])
+
+    return res_df, per_cui_annos_missed
+
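
A minimal usage sketch follows; it is not part of the commit above. It exercises evaluate_predictions (and, through it, the enclosure check in _anno_within_pred_list) on made-up data. The example text, the CUIs 'N1' and 'D1', and the SimpleNamespace object are illustrative assumptions: the SimpleNamespace only stands in for a real MedCAT CDB because the function reads nothing but cui2preferred_name from it. In practice the DeId model's own CDB and annotations exported from a MedCATTrainer project would be passed.

# Hedged sketch: assumes the two functions above are importable or defined in the session.
from types import SimpleNamespace

texts = ["John Smith attended on 1 Jan 2020."]
# Ground-truth annotations per text, using MedCATTrainer-style span dicts.
true_annotations = [[{'cui': 'N1', 'start': 0, 'end': 10},    # "John Smith"
                     {'cui': 'D1', 'start': 23, 'end': 33}]]  # "1 Jan 2020"
# Model predictions per text: the name span is found, the date is missed entirely.
all_preds = [[{'cui': 'N1', 'start': 0, 'end': 10}]]
# Stand-in for a real CDB; only cui2preferred_name is needed here.
fake_cdb = SimpleNamespace(cui2preferred_name={'N1': 'NAME', 'D1': 'DATE'})

res_df, missed = evaluate_predictions(true_annotations, all_preds, texts, fake_cdb)
print(res_df)        # NAME row: recall/precision 1.0; DATE row: 0.0 (no prediction covers it)
print(missed['D1'])  # the missed date annotation, grouped per document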