Commit 092729d

Author: Tom Searle
CU-8698jzjj3: Add tests and respond to comments
1 parent 3713aa2 · commit 092729d

File tree

3 files changed: +251 -53 lines changed

medcat-v1/medcat/utils/ner/deid.py

Lines changed: 52 additions & 44 deletions
@@ -184,23 +184,25 @@ def _is_deid_model(cls, cat: CAT) -> bool:
     @classmethod
     def _get_reason_not_deid(cls, cat: CAT) -> str:
         if cat.vocab is not None:
-            return "Has vocab"
+            return "Has vocab"
         if len(cat._addl_ner) != 1:
             return f"Incorrect number of addl_ner: {len(cat._addl_ner)}"
         return ""
 
 
-def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT) -> List[List[Dict]]:
+def match_rules(rules: List[Tuple[str, str]], texts: List[str], cui2preferred_name: Dict[str, str]) -> List[List[Dict]]:
     """Match a set of rules - pat / cui combos as post processing labels.
 
     Uses a cat DeID model for pretty name mapping.
 
     Args:
         rules (List[Tuple[str, str]]): List of tuples of pattern and cui
         texts (List[str]): List of texts to match rules on
-        cat (CAT): The CAT instance
+        cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred name, likely to be cat.cdb.cui2preferred_name.
 
     Examples:
+        >>> cat = CAT.load_model_pack(model_pack_path)
+        ...
         >>> rules = [
             ('(123) 456-7890', '134'),
             ('1234567890', '134'),
@@ -214,7 +216,7 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT) -> Lis
             'My phone number is 123.456.7890',
             'My phone number is 1234567890',
         ]
-        >>> matches = match_rules(rules, texts, cat)
+        >>> matches = match_rules(rules, texts, cat.cdb.cui2preferred_name)
 
     Returns:
         List[List[Dict]]: List of lists of predictions from `match_rules`
@@ -230,21 +232,20 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT) -> Lis
         for match in text_matches:
             matches_in_text.append({
                 'source_value': match.group(),
-                'pretty_name': cat.cdb.cui2preferred_name[concept],
+                'pretty_name': cui2preferred_name[concept],
                 'start': match.start(),
                 'end': match.end(),
                 'cui': concept,
-                'acc': 1.0,
-                'soure_value': match.group(0)
+                'acc': 1.0
             })
         rule_matches_per_text.append(matches_in_text)
     return rule_matches_per_text
 
 
-def merge_preds(model_preds_by_text: List[List[Dict]],
-                rule_matches_per_text: List[List[Dict]],
-                accept_preds: bool = True) -> List[List[Dict]]:
-    """Merge predictions from rule based and deID model predictions.
+def merge_all_preds(model_preds_by_text: List[List[Dict]],
+                    rule_matches_per_text: List[List[Dict]],
+                    accept_preds: bool = True) -> List[List[Dict]]:
+    """Convenience method to merge predictions from rule based and deID model predictions.
 
     Args:
         model_preds_by_text (List[Dict]): list of predictions from
@@ -255,56 +256,63 @@ def merge_preds(model_preds_by_text: List[List[Dict]],
             model_preds_by_text, over the rule matches if they overlap.
             Defaults to using model preds over rules.
 
+    Returns:
+        List[List[Dict]]: List of lists of predictions from `merge_all_preds`
+    """
+    assert len(model_preds_by_text) == len(rule_matches_per_text), \
+        "model_preds_by_text and rule_matches_per_text must have the same length as they should be CAT.get_entities and match_rules outputs of the same text"
+    return [merge_preds(model_preds_by_text[i], rule_matches_per_text[i], accept_preds) for i in range(len(model_preds_by_text))]
+
+
+def merge_preds(model_preds: List[Dict],
+                rule_matches: List[Dict],
+                accept_preds: bool = True) -> List[Dict]:
+    """Merge predictions from rule based and deID model predictions.
+
+    Args:
+        model_preds (List[Dict]): predictions from `cat.get_entities()`
+        rule_matches (List[Dict]): predictions from output of running `match_rules` on a text
+        accept_preds (bool): uses the predicted label from the model,
+            model_preds, over the rule matches if they overlap.
+            Defaults to using model preds over rules.
+
     Examples:
-        >>> # a list of lists of predictions from `cat.get_entities()`
-        >>> model_preds_by_text = [
+        >>> # a list of predictions from `cat.get_entities()`
+        >>> model_preds = [
             [
                 {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
                  'pretty_name': 'Phone Number'},
                 {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
                  'pretty_name': 'Phone Number'}
             ]
         ]
-        >>> # a list of lists of predictions from `match_rules`
-        >>> rule_matches_by_text = [
+        >>> # a list of predictions from `match_rules`
+        >>> rule_matches = [
             [
                 {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
                  'pretty_name': 'Phone Number'},
                 {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
                  'pretty_name': 'Phone Number'}
             ]
         ]
-        >>> merged_preds = merge_preds(model_preds_by_text, rule_matches_by_text)
+        >>> merged_preds = merge_preds(model_preds, rule_matches)
 
     Returns:
-        List[List[Dict]]: List of lists of predictions from `merge_preds`
+        List[Dict]: List of predictions from `merge_preds`
     """
-    all_preds = []
     if accept_preds:
-        labels1 = model_preds_by_text
-        labels2 = rule_matches_per_text
+        labels1 = model_preds
+        labels2 = rule_matches
     else:
-        labels1 = rule_matches_per_text
-        labels2 = model_preds_by_text
-    for matches_text1, matches_text2 in zip(labels1, labels2):
-        # Function to check if two spans overlap
-        def has_overlap(span1, span2):
-            return not (span1['end'] <= span2['start'] or
-                        span2['end'] <= span1['start'])
-
-        # Mark model predictions that overlap with rule matches
-        to_remove = set()
-        for text_match1 in matches_text1:
-            for i, text_match2 in enumerate(matches_text2):
-                if has_overlap(text_match1, text_match2):
-                    to_remove.add(i)
-
-        # Keep only non-overlapping model predictions
-        matches_text2 = [text_match for i, text_match in
-                         enumerate(matches_text2) if i not in to_remove]
-
-        # merge preds and sort on start
-        merged_preds = matches_text1 + matches_text2
-        merged_preds.sort(key=lambda x: x['start'])
-        all_preds.append(merged_preds)
-    return all_preds
+        labels1 = rule_matches
+        labels2 = model_preds
+
+    # Keep only non-overlapping model predictions
+    labels2 = [span2 for span2 in labels2
+               if not any(not (span2['end'] <= span1['start'] or span1['end'] <= span2['start'])
+                          for span1 in labels1)]
+    # merge preds and sort on start
+    merged_preds = labels1 + labels2
+    merged_preds.sort(key=lambda x: x['start'])
+    return merged_preds
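
With this change, rule matching and prediction merging no longer need the full CAT instance at call time, only the CUI-to-preferred-name mapping. A minimal usage sketch of the new flow, assuming a DeID model pack is available locally and that per-text model predictions come from `cat.get_entities` as the docstrings describe (the model pack path below is a placeholder; the rules and texts mirror the docstring example):

from medcat.cat import CAT
from medcat.utils.ner.deid import match_rules, merge_all_preds

# Placeholder model pack path; rules are (pattern, CUI) pairs as in the docstring example.
cat = CAT.load_model_pack("deid_model_pack.zip")
rules = [
    ('(123) 456-7890', '134'),
    ('1234567890', '134'),
]
texts = [
    'My phone number is 123.456.7890',
    'My phone number is 1234567890',
]

# One list of entity dicts per text, taken from cat.get_entities (assumed shape:
# each dict carries 'cui', 'start', 'end', 'acc', 'pretty_name').
model_preds_by_text = [list(cat.get_entities(t)['entities'].values()) for t in texts]

# Rule matches per text; only the CUI -> preferred name mapping is passed now.
rule_matches_per_text = match_rules(rules, texts, cat.cdb.cui2preferred_name)

# Merge per text; with accept_preds=True a rule match that overlaps a model
# prediction is dropped in favour of the model prediction, then spans are sorted by start.
merged = merge_all_preds(model_preds_by_text, rule_matches_per_text, accept_preds=True)

Setting accept_preds=False flips the preference, keeping rule matches over any overlapping model predictions.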

medcat-v1/medcat/utils/ner/metrics.py

Lines changed: 4 additions & 6 deletions
@@ -6,8 +6,6 @@
 from scipy.special import softmax
 import logging
 
-from medcat.cdb import CDB
-
 
 logger = logging.getLogger(__name__)
 
@@ -32,7 +30,7 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer
     Returns:
         Dict: A dictionary of metrics.
     """
-    """TODO: This could be done better, for sure. But it works.""" # noqa
+
     predictions = np.array(p.predictions)
     predictions = softmax(predictions, axis=2)
     examples = None
@@ -154,7 +152,7 @@ def _anno_within_pred_list(label: Dict, preds: List[Dict]) -> bool:
     return any(label['start'] >= p['start'] and label['end'] <= p['end'] for p in preds)
 
 
-def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[List[Dict]], texts: List[str], deid_cdb: CDB):
+def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[List[Dict]], texts: List[str], cui2preferred_name: Dict[str, str]):
     """
     Evaluate predictions against sets of collected labels as collected and output from a MedCATTrainer project.
     Counts predictions as correct if the prediction fully encloses the label.
@@ -163,7 +161,7 @@ def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[Lis
         true_annotations (List[List[Dict]]): Ground truth predictions by text
         all_preds (List[List[Dict]]): Model predictions by text
         texts (List[str]): Original list of texts
-        deid_cdb (CDB): Concept database
+        cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred name, likely to be cat.cdb.cui2preferred_name.
 
     Returns:
         Tuple[pd.DataFrame, Dict]: A tuple containing a DataFrame of evaluation metrics and a dictionary of missed annotations per CUI.
@@ -219,6 +217,6 @@ def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[Lis
         'recall_merged': per_cui_recall_merged.values(),
         'recall': per_cui_recall.values(),
         'precision': per_cui_prec.values(),
-        'label_count': per_cui_anno_counts.values()}, index=[deid_cdb.cui2preferred_name[k] for k in per_cui_recall_merged])
+        'label_count': per_cui_anno_counts.values()}, index=[cui2preferred_name[k] for k in per_cui_recall_merged])
 
     return res_df, per_cui_annos_missed
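
The `evaluate_predictions` call site changes in the same way: callers hand over the name mapping rather than the concept database. A rough sketch under the assumption that `true_annotations`, `all_preds` and `texts` already hold the per-text lists the docstring describes (the literal values below are illustrative only, and the exact keys a ground-truth annotation needs beyond cui/start/end are an assumption):

from medcat.utils.ner.metrics import evaluate_predictions

# Illustrative per-text inputs following the docstring shapes.
texts = ['My phone number is 123.456.7890']
true_annotations = [[{'cui': '134', 'start': 19, 'end': 31}]]
all_preds = [[{'cui': '134', 'start': 19, 'end': 31, 'acc': 1.0,
               'pretty_name': 'Phone Number'}]]

# Previously the CDB object was passed (deid_cdb); now only the mapping is needed,
# e.g. cat.cdb.cui2preferred_name from a loaded model pack, or a plain dict.
res_df, missed_per_cui = evaluate_predictions(
    true_annotations, all_preds, texts, {'134': 'Phone Number'})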