@@ -63,7 +63,7 @@ class DeIdModel(NerModel):
6363 def __init__ (self , cat : CAT ) -> None :
6464 self .cat = cat
6565
66- def train (self , json_path : Union [str , list , None ]= None ,
66+ def train (self , json_path : Union [str , list , None ] = None ,
6767 * args , ** kwargs ) -> Tuple [Any , Any , Any ]:
6868 assert not all ([json_path , kwargs .get ('train_json_path' ), kwargs .get ('test_json_path' )]), \
6969 "Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
@@ -149,7 +149,8 @@ def deid_multi_texts(self,
149149 return out
150150
151151 @classmethod
152- def load_model_pack (cls , model_pack_path : str , config : Optional [Dict ] = None ) -> 'DeIdModel' :
152+ def load_model_pack (cls , model_pack_path : str ,
153+ config : Optional [Dict ] = None ) -> 'DeIdModel' :
153154 """Load DeId model from model pack.
154155
155156 The method first loads the CAT instance.
@@ -167,7 +168,7 @@ def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) ->
167168 Returns:
168169 DeIdModel: The resulting DeId model.
169170 """
170- ner_model = NerModel .load_model_pack (model_pack_path ,config = config )
171+ ner_model = NerModel .load_model_pack (model_pack_path , config = config )
171172 cat = ner_model .cat
172173 if not cls ._is_deid_model (cat ):
173174 raise ValueError (
@@ -190,25 +191,25 @@ def _get_reason_not_deid(cls, cat: CAT) -> str:
190191
191192
192193def match_rules (rules : List [Tuple [str , str ]], texts : List [str ], cat : CAT ):
193- """
194- Match a set of rules - pat / cui combos as post processing labels, uses
195- a cat DeID model forp pretty name mapping
194+ """Match a set of rules - pat / cui combos as post processing labels.
195+
196+ Uses a cat DeID model for pretty name mapping.
196197
197198 Examples:
198- >>> rules = [
199- ('(123) 456-7890', '134'),
200- ('1234567890', '134'),
201- ('123.456.7890', '134'),
202- ('1234567890', '134'),
203- ('1234567890', '134'),
204- ]
205- >>> texts = [
206- 'My phone number is (123) 456-7890',
207- 'My phone number is 1234567890',
208- 'My phone number is 123.456.7890',
209- 'My phone number is 1234567890',
210- ]
211- >>> matches = match_rules(rules, texts, cat)
199+ >>> rules = [
200+ ('(123) 456-7890', '134'),
201+ ('1234567890', '134'),
202+ ('123.456.7890', '134'),
203+ ('1234567890', '134'),
204+ ('1234567890', '134'),
205+ ]
206+ >>> texts = [
207+ 'My phone number is (123) 456-7890',
208+ 'My phone number is 1234567890',
209+ 'My phone number is 123.456.7890',
210+ 'My phone number is 1234567890',
211+ ]
212+ >>> matches = match_rules(rules, texts, cat)
212213 """
213214 # Iterate through each text and pattern combination
214215 rule_matches_per_text = []
@@ -217,7 +218,6 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
217218 for pattern , concept in rules :
218219 # Find all matches of current pattern in current text
219220 text_matches = re .finditer (pattern , text , flags = re .M )
220-
221221 # Add each match with its pattern and text info
222222 for match in text_matches :
223223 matches_in_text .append ({
@@ -233,31 +233,40 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
233233 return rule_matches_per_text
234234
235235
236- def merge_preds (model_preds_by_text : List [List [Dict ]], rule_matches_per_text : List [List [Dict ]], accept_preds = True ):
237- """
238- Merge predictions from rule based and deID model predictions for further evaluation
236+ def merge_preds (model_preds_by_text : List [List [Dict ]],
237+ rule_matches_per_text : List [List [Dict ]],
238+ accept_preds : bool = True ):
239+ """Merge predictions from rule based and deID model predictions.
239240
240- Args:
241- model_preds_by_text (List[Dict]): list of predictions from `cat.get_entities()`, then `[list(m['entities'].values()) for m in model_preds]`
242- rule_matches_by_text (List[Dict]): list of predictions from output of running `match_rules`
243- accept_preds (bool): uses the predicted label from the model, model_preds_by_text, over the rule matches if they overlap. Defaults to using model preds over rules.
241+ Args:
242+ model_preds_by_text (List[List[Dict]]): list of predictions from
243+ `cat.get_entities()`, then `[list(m['entities'].values()) for m in model_preds]`
244+ rule_matches_per_text (List[List[Dict]]): list of predictions from output of
245+ running `match_rules`
246+ accept_preds (bool): uses the predicted label from the model,
247+ model_preds_by_text, over the rule matches if they overlap.
248+ Defaults to using model preds over rules.
244249
245250 Examples:
246- >>> # a list of lists of predictions from `cat.get_entities()`
247- >>> model_preds_by_text = [
248- [
249- {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, 'pretty_name': 'Phone Number'},
250- {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, 'pretty_name': 'Phone Number'}
251+ >>> # a list of lists of predictions from `cat.get_entities()`
252+ >>> model_preds_by_text = [
253+ [
254+ {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
255+ 'pretty_name': 'Phone Number'},
256+ {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
257+ 'pretty_name': 'Phone Number'}
258+ ]
251259 ]
252- ]
253- >>> # a list of lists of predictions from `match_rules`
254- >>> rule_matches_by_text = [
255- [
256- {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, 'pretty_name': 'Phone Number'},
257- {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, 'pretty_name': 'Phone Number'}
260+ >>> # a list of lists of predictions from `match_rules`
261+ >>> rule_matches_by_text = [
262+ [
263+ {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
264+ 'pretty_name': 'Phone Number'},
265+ {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
266+ 'pretty_name': 'Phone Number'}
267+ ]
258268 ]
259- ]
260- >>> merged_preds = merge_preds(model_preds_by_text, rule_matches_by_text)
269+ >>> merged_preds = merge_preds(model_preds_by_text, rule_matches_by_text)
261270 """
262271 all_preds = []
263272 if accept_preds :
@@ -266,28 +275,25 @@ def merge_preds(model_preds_by_text: List[List[Dict]], rule_matches_per_text: Li
266275 else :
267276 labels1 = rule_matches_per_text
268277 labels2 = model_preds_by_text
269-
270278 for matches_text1 , matches_text2 in zip (labels1 , labels2 ):
271279 # Function to check if two spans overlap
272280 def has_overlap (span1 , span2 ):
273- return not (span1 ['end' ] <= span2 ['start' ] or span2 ['end' ] <= span1 ['start' ])
274-
281+ return not (span1 ['end' ] <= span2 ['start' ] or
282+ span2 ['end' ] <= span1 ['start' ])
283+
275284 # Mark model predictions that overlap with rule matches
276-
277285 to_remove = set ()
278286 for text_match1 in matches_text1 :
279287 for i , text_match2 in enumerate (matches_text2 ):
280288 if has_overlap (text_match1 , text_match2 ):
281289 to_remove .add (i )
282-
290+
283291 # Keep only non-overlapping model predictions
284- matches_text2 = [text_match for i , text_match in enumerate (matches_text2 ) if i not in to_remove ]
285-
292+ matches_text2 = [text_match for i , text_match in
293+ enumerate (matches_text2 ) if i not in to_remove ]
294+
286295 # merge preds and sort on start
287296 merged_preds = matches_text1 + matches_text2
288297 merged_preds .sort (key = lambda x : x ['start' ])
289298 all_preds .append (merged_preds )
290299 return all_preds
291-
292-
293-
0 commit comments