@@ -184,23 +184,25 @@ def _is_deid_model(cls, cat: CAT) -> bool:
184184 @classmethod
185185 def _get_reason_not_deid (cls , cat : CAT ) -> str :
186186 if cat .vocab is not None :
187- return "Has vocab "
187+ return "Has voc§ab "
188188 if len (cat ._addl_ner ) != 1 :
189189 return f"Incorrect number of addl_ner: { len (cat ._addl_ner )} "
190190 return ""
191191
192192
193- def match_rules (rules : List [Tuple [str , str ]], texts : List [str ], cat : CAT ) -> List [List [Dict ]]:
193+ def match_rules (rules : List [Tuple [str , str ]], texts : List [str ], cui2preferred_name : Dict [ str , str ] ) -> List [List [Dict ]]:
194194 """Match a set of rules - pat / cui combos as post processing labels.
195195
196196 Uses a cat DeID model for pretty name mapping.
197197
198198 Args:
199199 rules (List[Tuple[str, str]]): List of tuples of pattern and cui
200200 texts (List[str]): List of texts to match rules on
201- cat (CAT ): The CAT instance
201+ cui2preferred_name (Dict[str, str] ): Dictionary of CUI to preferred name, likely to be cat.cdb.cui2preferred_name.
202202
203203 Examples:
204+ >>> cat = CAT.load_model_pack(model_pack_path)
205+ ...
204206 >>> rules = [
205207 ('(123) 456-7890', '134'),
206208 ('1234567890', '134'),
@@ -214,7 +216,7 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT) -> Lis
214216 'My phone number is 123.456.7890',
215217 'My phone number is 1234567890',
216218 ]
217- >>> matches = match_rules(rules, texts, cat)
219+ >>> matches = match_rules(rules, texts, cat.cdb.cui2preferred_name )
218220
219221 Returns:
220222 List[List[Dict]]: List of lists of predictions from `match_rules`
@@ -230,21 +232,20 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT) -> Lis
230232 for match in text_matches :
231233 matches_in_text .append ({
232234 'source_value' : match .group (),
233- 'pretty_name' : cat . cdb . cui2preferred_name [concept ],
235+ 'pretty_name' : cui2preferred_name [concept ],
234236 'start' : match .start (),
235237 'end' : match .end (),
236238 'cui' : concept ,
237- 'acc' : 1.0 ,
238- 'soure_value' : match .group (0 )
239+ 'acc' : 1.0
239240 })
240241 rule_matches_per_text .append (matches_in_text )
241242 return rule_matches_per_text
242243
243244
244- def merge_preds (model_preds_by_text : List [List [Dict ]],
245- rule_matches_per_text : List [List [Dict ]],
246- accept_preds : bool = True ) -> List [List [Dict ]]:
247- """Merge predictions from rule based and deID model predictions.
245+ def merge_all_preds (model_preds_by_text : List [List [Dict ]],
246+ rule_matches_per_text : List [List [Dict ]],
247+ accept_preds : bool = True ) -> List [List [Dict ]]:
248+ """Conveniance method to merge predictions from rule based and deID model predictions.
248249
249250 Args:
250251 model_preds_by_text (List[Dict]): list of predictions from
@@ -255,56 +256,63 @@ def merge_preds(model_preds_by_text: List[List[Dict]],
255256 model_preds_by_text, over the rule matches if they overlap.
256257 Defaults to using model preds over rules.
257258
259+ Returns:
260+ List[List[Dict]]: List of lists of predictions from `merge_all_preds`
261+ """
262+ assert len (model_preds_by_text ) == len (rule_matches_per_text ), \
263+ "model_preds_by_text and rule_matches_per_text must have the same length as they should be CAT.get_entities and match_rules outputs of the same text"
264+ return [merge_preds (model_preds_by_text [i ], rule_matches_per_text [i ], accept_preds ) for i in range (len (model_preds_by_text ))]
265+
266+
267+ def merge_preds (model_preds : List [Dict ],
268+ rule_matches : List [Dict ],
269+ accept_preds : bool = True ) -> List [Dict ]:
270+ """Merge predictions from rule based and deID model predictions.
271+
272+ Args:
273+ model_preds (List[Dict]): predictions from `cat.get_entities()`
274+ rule_matches (List[Dict]): predictions from output of running `match_rules` on a text
275+ accept_preds (bool): uses the predicted label from the model,
276+ model_preds, over the rule matches if they overlap.
277+ Defaults to using model preds over rules.
278+
258279 Examples:
259- >>> # a list of lists of predictions from `cat.get_entities()`
260- >>> model_preds_by_text = [
280+ >>> # a list of predictions from `cat.get_entities()`
281+ >>> model_preds = [
261282 [
262283 {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
263284 'pretty_name': 'Phone Number'},
264285 {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
265286 'pretty_name': 'Phone Number'}
266287 ]
267288 ]
268- >>> # a list of lists of predictions from `match_rules`
269- >>> rule_matches_by_text = [
289+ >>> # a list of predictions from `match_rules`
290+ >>> rule_matches = [
270291 [
271292 {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
272293 'pretty_name': 'Phone Number'},
273294 {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
274295 'pretty_name': 'Phone Number'}
275296 ]
276297 ]
277- >>> merged_preds = merge_preds(model_preds_by_text, rule_matches_by_text )
298+ >>> merged_preds = merge_preds(model_preds, rule_matches )
278299
279300 Returns:
280- List[List[ Dict]] : List of lists of predictions from `merge_preds`
301+ List[Dict]: List of predictions from `merge_preds`
281302 """
282- all_preds = []
283303 if accept_preds :
284- labels1 = model_preds_by_text
285- labels2 = rule_matches_per_text
304+ labels1 = model_preds
305+ labels2 = rule_matches
286306 else :
287- labels1 = rule_matches_per_text
288- labels2 = model_preds_by_text
289- for matches_text1 , matches_text2 in zip (labels1 , labels2 ):
290- # Function to check if two spans overlap
291- def has_overlap (span1 , span2 ):
292- return not (span1 ['end' ] <= span2 ['start' ] or
293- span2 ['end' ] <= span1 ['start' ])
294-
295- # Mark model predictions that overlap with rule matches
296- to_remove = set ()
297- for text_match1 in matches_text1 :
298- for i , text_match2 in enumerate (matches_text2 ):
299- if has_overlap (text_match1 , text_match2 ):
300- to_remove .add (i )
301-
302- # Keep only non-overlapping model predictions
303- matches_text2 = [text_match for i , text_match in
304- enumerate (matches_text2 ) if i not in to_remove ]
305-
306- # merge preds and sort on start
307- merged_preds = matches_text1 + matches_text2
308- merged_preds .sort (key = lambda x : x ['start' ])
309- all_preds .append (merged_preds )
310- return all_preds
307+ labels1 = rule_matches
308+ labels2 = model_preds
309+
310+ # Keep only non-overlapping model predictions
311+ labels2 = [span2 for span2 in labels2
312+ if not any (not (span2 ['end' ] <= span1 ['start' ] or span1 ['end' ] <= span2 ['start' ])
313+ for span1 in labels1 )]
314+ # merge preds and sort on start
315+ merged_preds = labels1 + labels2
316+ merged_preds .sort (key = lambda x : x ['start' ])
317+ merged_preds
318+ return merged_preds
0 commit comments