Skip to content

Commit 12af3e3

Browse files
author
Tom Searle
committed
CU-8698jzjj3: fix mypy errors
1 parent b97b68f commit 12af3e3

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

medcat-v1/medcat/ner/transformers_ner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ def train(self,
177177
ignore_extra_labels=False,
178178
dataset=None,
179179
meta_requirements=None,
180-
train_json_path: str=None,
181-
test_json_path: str=None,
180+
train_json_path: Union[str, list, None]=None,
181+
test_json_path: Union[str, list, None]=None,
182182
trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple:
183183
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
184184
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -571,5 +571,6 @@ def func_has_kwarg(func: Callable, keyword: str):
571571
test_json_path = 'test_set.json'
572572
deid_model_path = '/Users/k1897038/Documents/cogstack_docs/medcat_models/medcat_deid_model_691c3f6a6e5400e7.zip'
573573
deid_cat = DeIdModel.load_model_pack(deid_model_path)
574-
deid_cat.train(train_json_path=train_json_path, test_json_path=test_json_path)
574+
deid_cat.train(ignore_extra_labels=True, train_json_path=train_json_path, test_json_path=test_json_path)
575+
575576

medcat-v1/medcat/utils/ner/deid.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def train(self, json_path: Union[str, list, None]=None,
6969
*args, **kwargs) -> Tuple[Any, Any, Any]:
7070
assert not all([json_path, train_json_path, test_json_path]), \
7171
"Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
72-
7372
return super().train(json_path=json_path,
7473
train_json_path=train_json_path,
7574
test_json_path=test_json_path, *args, **kwargs) # type: ignore

medcat-v1/medcat/utils/ner/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def __init__(self, cat: CAT) -> None:
2626
self.cat = cat
2727

2828
def train(self, json_path: Union[str, list, None], train_nr: int = 0,
29+
train_json_path: Union[str, list, None]=None,
30+
test_json_path: Union[str, list, None]=None,
2931
*args, **kwargs) -> Tuple[Any, Any, Any]:
3032
"""Train the underlying transformers NER model.
3133
@@ -40,7 +42,7 @@ def train(self, json_path: Union[str, list, None], train_nr: int = 0,
4042
Returns:
4143
Tuple[Any, Any, Any]: df, examples, dataset
4244
"""
43-
return self.cat._addl_ner[train_nr].train(json_path, *args, **kwargs)
45+
return self.cat._addl_ner[train_nr].train(json_path, train_json_path=train_json_path, test_json_path=test_json_path, *args, **kwargs)
4446

4547
def eval(self, json_path: Union[str, list, None], train_nr: int = 0,
4648
*args, **kwargs) -> Tuple[Any, Any, Any]:

0 commit comments

Comments
 (0)