
Commit 720dabb

Author: Tom Searle (committed)

CU-86995ddvj: Add in post-processing funcs for a de-id pipeline

1 parent 98302e1 · commit 720dabb

File tree

2 files changed: +202 -0 lines changed


medcat-v1/medcat/utils/ner/deid.py

Lines changed: 105 additions & 0 deletions
@@ -34,6 +34,7 @@
    - config
    - cdb
"""
import re
from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
import logging

@@ -187,3 +188,107 @@ def _get_reason_not_deid(cls, cat: CAT) -> str:
        if len(cat._addl_ner) != 1:
            return f"Incorrect number of addl_ner: {len(cat._addl_ner)}"
        return ""


def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
    r"""Match a set of rules - pattern / CUI combos - as post-processing labels, using
    a CAT DeID model for pretty name mapping.

    Examples:
        >>> rules = [
        ...     (r'\(\d{3}\) \d{3}-\d{4}', '134'),
        ...     (r'\d{10}', '134'),
        ...     (r'\d{3}\.\d{3}\.\d{4}', '134'),
        ... ]
        >>> texts = [
        ...     'My phone number is (123) 456-7890',
        ...     'My phone number is 1234567890',
        ...     'My phone number is 123.456.7890',
        ... ]
        >>> matches = match_rules(rules, texts, cat)
    """
    # Iterate through each text and pattern combination
    rule_matches_per_text = []
    for text in texts:
        matches_in_text = []
        for pattern, concept in rules:
            # Find all matches of the current pattern in the current text
            text_matches = re.finditer(pattern, text, flags=re.M)

            # Record each match with its span, CUI and pretty name
            for match in text_matches:
                matches_in_text.append({
                    'source_value': match.group(),
                    'pretty_name': cat.cdb.cui2preferred_name[concept],
                    'start': match.start(),
                    'end': match.end(),
                    'cui': concept,
                    'acc': 1.0,
                })
        rule_matches_per_text.append(matches_in_text)
    return rule_matches_per_text


def merge_preds(model_preds_by_text: List[List[Dict]], rule_matches_per_text: List[List[Dict]], accept_preds=True):
    """Merge predictions from rule-based matching and DeID model predictions for further evaluation.

    Args:
        model_preds_by_text (List[List[Dict]]): list of predictions from `cat.get_entities()`,
            then `[list(m['entities'].values()) for m in model_preds]`
        rule_matches_per_text (List[List[Dict]]): list of predictions output by `match_rules`
        accept_preds (bool): if True, prefer the model predictions (model_preds_by_text) over
            rule matches wherever they overlap. Defaults to preferring model preds over rules.

    Examples:
        >>> # a list of lists of predictions from `cat.get_entities()`
        >>> model_preds_by_text = [
        ...     [
        ...         {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, 'pretty_name': 'Phone Number'},
        ...         {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, 'pretty_name': 'Phone Number'}
        ...     ]
        ... ]
        >>> # a list of lists of predictions from `match_rules`
        >>> rule_matches_per_text = [
        ...     [
        ...         {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, 'pretty_name': 'Phone Number'},
        ...         {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, 'pretty_name': 'Phone Number'}
        ...     ]
        ... ]
        >>> merged_preds = merge_preds(model_preds_by_text, rule_matches_per_text)
    """
    all_preds = []
    if accept_preds:
        labels1 = model_preds_by_text
        labels2 = rule_matches_per_text
    else:
        labels1 = rule_matches_per_text
        labels2 = model_preds_by_text

    def has_overlap(span1, span2):
        # Two spans overlap unless one ends before the other starts
        return not (span1['end'] <= span2['start'] or span2['end'] <= span1['start'])

    for matches_text1, matches_text2 in zip(labels1, labels2):
        # Mark lower-priority matches that overlap with a higher-priority one
        to_remove = set()
        for text_match1 in matches_text1:
            for i, text_match2 in enumerate(matches_text2):
                if has_overlap(text_match1, text_match2):
                    to_remove.add(i)

        # Keep only the non-overlapping lower-priority matches
        matches_text2 = [text_match for i, text_match in enumerate(matches_text2) if i not in to_remove]

        # Merge the two sets of predictions and sort by start offset
        merged_preds = matches_text1 + matches_text2
        merged_preds.sort(key=lambda x: x['start'])
        all_preds.append(merged_preds)
    return all_preds
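
For orientation, here is a minimal usage sketch (not part of the commit) showing how the two new helpers could be wired into a de-id post-processing pipeline. It assumes `cat` is an already-loaded MedCAT DeID model; the phone-number regex and the CUI '134' are illustrative placeholders.

# Sketch only: `cat` is assumed to be a loaded MedCAT DeID model (CAT instance)
# whose CDB maps the CUI '134' to a phone-number concept.
from medcat.utils.ner.deid import match_rules, merge_preds

texts = ['My phone number is 123.456.7890']
rules = [(r'\d{3}[.\- ]?\d{3}[.\- ]?\d{4}', '134')]  # illustrative pattern -> CUI

# Model predictions per text, flattened to lists of entity dicts
model_preds = [cat.get_entities(t) for t in texts]
model_preds_by_text = [list(m['entities'].values()) for m in model_preds]

# Rule matches per text, then merged; model predictions win on overlap
rule_matches_per_text = match_rules(rules, texts, cat)
all_preds = merge_preds(model_preds_by_text, rule_matches_per_text, accept_preds=True)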

medcat-v1/medcat/utils/ner/metrics.py

Lines changed: 97 additions & 0 deletions
@@ -1,16 +1,28 @@
from typing import Dict, List
from sklearn.metrics import classification_report
import numpy as np
import pandas as pd
from collections import defaultdict
from scipy.special import softmax
import logging

from medcat.cdb import CDB


logger = logging.getLogger(__name__)


def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, merged_negative={0, 1, -100}, padding_label=-100, csize=15, subword_label=1,
            verbose=False):
    """Calculate metrics for a model's predictions, based on the tokenized output of a MedCATTrainer project.

    Args:
        p: The model's predictions.
        return_df: Whether to return a DataFrame of metrics.
        plus_recall: The recall to add to the model's predictions.
        tokenizer: The tokenizer used to tokenize the texts.
    """
    """TODO: This could be done better, for sure. But it works."""  # noqa
    predictions = np.array(p.predictions)
    predictions = softmax(predictions, axis=2)

@@ -117,3 +129,88 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer
                'precison_merged': np.average([x for x in df.p_merged.values if pd.notna(x)])}
    else:
        return df, examples


def _anno_within_pred_list(label: Dict, preds: List[Dict]) -> bool:
    """Check if a label falls within any prediction in a list of predictions.

    Args:
        label (Dict): an annotation, likely from a MedCATTrainer project
        preds (List[Dict]): a list of predictions, likely from `cat.__call__`

    Returns:
        bool: True if the label is within the list of predictions, False otherwise
    """
    return any(label['start'] >= p['start'] and label['end'] <= p['end'] for p in preds)


def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[List[Dict]], texts: List[str], deid_cdb: CDB):
    """Evaluate predictions against sets of labels as collected and output from a MedCATTrainer project.
    A prediction counts as correct if it fully encloses the label.

    Args:
        true_annotations (List[List[Dict]]): Ground truth annotations by text
        all_preds (List[List[Dict]]): Model predictions by text
        texts (List[str]): Original list of texts
        deid_cdb (CDB): Concept database

    Returns:
        Tuple[pd.DataFrame, Dict]: A DataFrame of evaluation metrics and a dictionary of missed annotations per CUI.
    """
    per_cui_recall = {}
    per_cui_prec = {}
    per_cui_recall_merged = {}
    per_cui_anno_counts = {}
    per_cui_annos_missed = defaultdict(list)
    uniq_labels = set(p['cui'] for ap in true_annotations for p in ap)

    for cui in uniq_labels:
        # Annotations of this CUI in the test set, and predictions of this CUI
        anno_count = sum(len([p for p in cui_annos if p['cui'] == cui]) for cui_annos in true_annotations)
        pred_counts = sum(len([p for p in d if p['cui'] == cui]) for d in all_preds)

        per_cui_anno_counts[cui] = anno_count

        doc_annos_left, preds_left, doc_annos_left_any_cui = [], [], []

        for doc_preds, doc_labels, text in zip(all_preds, true_annotations, texts):
            # Number of annotations that are not found - recall
            cui_labels = [label for label in doc_labels if label['cui'] == cui]
            cui_doc_preds = [p for p in doc_preds if p['cui'] == cui]

            labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, cui_doc_preds)]
            doc_annos_left.append(len(labels_not_found))

            # Number of annotations not found across predictions of any CUI - recall_merged
            any_labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, doc_preds)]
            doc_annos_left_any_cui.append(len(any_labels_not_found))

            per_cui_annos_missed[cui].append(any_labels_not_found)

            # Number of predictions that are incorrect - precision
            preds_left.append(len([label for label in cui_doc_preds if not _anno_within_pred_list(label, cui_labels)]))

        if anno_count != 0 and pred_counts != 0:
            per_cui_recall[cui] = (anno_count - sum(doc_annos_left)) / anno_count
            per_cui_recall_merged[cui] = (anno_count - sum(doc_annos_left_any_cui)) / anno_count
            per_cui_prec[cui] = (pred_counts - sum(preds_left)) / pred_counts
        else:
            per_cui_recall[cui] = 0
            per_cui_recall_merged[cui] = 0
            per_cui_prec[cui] = 0

    res_df = pd.DataFrame({
        'cui': per_cui_recall_merged.keys(),
        'recall_merged': per_cui_recall_merged.values(),
        'recall': per_cui_recall.values(),
        'precision': per_cui_prec.values(),
        'label_count': per_cui_anno_counts.values()},
        index=[deid_cdb.cui2preferred_name[k] for k in per_cui_recall_merged])

    return res_df, per_cui_annos_missed
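
A companion sketch (again not part of the commit) for the evaluation side: `all_preds` and `cat` are assumed to come from the de-id sketch above, and the ground-truth annotations are pulled from a MedCATTrainer project export; the 'projects' / 'documents' / 'annotations' field names follow the usual trainer JSON layout and may need adapting.

# Sketch only: 'trainer_export.json' is a placeholder path.
import json
from medcat.utils.ner.metrics import evaluate_predictions

with open('trainer_export.json') as f:
    export = json.load(f)

texts, true_annotations = [], []
for project in export['projects']:
    for doc in project['documents']:
        texts.append(doc['text'])
        true_annotations.append(doc['annotations'])  # dicts with 'cui', 'start', 'end'

# Per-CUI recall / precision plus the annotations each CUI missed
res_df, missed_per_cui = evaluate_predictions(true_annotations, all_preds, texts, cat.cdb)
print(res_df.sort_values('recall_merged'))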
