@@ -64,14 +64,10 @@ def __init__(self, cat: CAT) -> None:
6464 self .cat = cat
6565
6666 def train (self , json_path : Union [str , list , None ]= None ,
67- train_json_path : Union [str , list , None ]= None ,
68- test_json_path : Union [str , list , None ]= None ,
6967 * args , ** kwargs ) -> Tuple [Any , Any , Any ]:
70- assert not all ([json_path , train_json_path , test_json_path ]), \
68+ assert not all ([json_path , kwargs . get ( ' train_json_path' ), kwargs . get ( ' test_json_path' ) ]), \
7169 "Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
72- return super ().train (json_path = json_path ,
73- train_json_path = train_json_path ,
74- test_json_path = test_json_path , * args , ** kwargs ) # type: ignore
70+ return super ().train (json_path = json_path , * args , ** kwargs ) # type: ignore
7571
7672 def eval (self , json_path : Union [str , list , None ],
7773 * args , ** kwargs ) -> Tuple [Any , Any , Any ]:
@@ -237,7 +233,7 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
237233 return rule_matches_per_text
238234
239235
240- def merge_preds (model_preds_by_text : List [Dict ], rule_matches_per_text : List [Dict ], accept_preds = True ):
236+ def merge_preds (model_preds_by_text : List [List [ Dict ]] , rule_matches_per_text : List [List [ Dict ] ], accept_preds = True ):
241237 """
242238 Merge predictions from rule based and deID model predictions for further evaluation
243239
0 commit comments