Skip to content

Commit 98302e1

Browse files
author
Tom Searle
committed
CU-86995ddec: Extra args on the train to pass in dataset to be split, or train / test files
1 parent c4c92ec commit 98302e1

File tree

3 files changed

+68
-15
lines changed

3 files changed

+68
-15
lines changed

medcat-v1/medcat/datasets/transformers_ner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def _split_generators(self, dl_manager): # noqa
7171
"filepaths": self.config.data_files['train'],
7272
},
7373
),
74+
datasets.SplitGenerator(
75+
name=datasets.Split.TEST,
76+
gen_kwargs={
77+
"filepaths": self.config.data_files['test'],
78+
},
79+
),
7480
]
7581

7682
def _generate_examples(self, filepaths): # noqa

medcat-v1/medcat/ner/transformers_ner.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ def train(self,
176176
ignore_extra_labels=False,
177177
dataset=None,
178178
meta_requirements=None,
179-
trainer_callbacks: Optional[List[TrainerCallback]]=None) -> Tuple:
179+
trainer_callbacks: Optional[List[TrainerCallback]]=None,
180+
train_json_path: str=None,
181+
test_json_path: str=None) -> Tuple:
180182
"""Train or continue training a model given a json_path containing a MedCATtrainer export. It will
181183
continue training if an existing model is loaded or start new training if the model is blank/new.
182184
@@ -186,21 +188,27 @@ def train(self,
186188
ignore_extra_labels:
187189
Only makes sense when an existing deid model was loaded and from the new data we want to ignore
188190
labels that did not exist in the old model.
189-
dataset: Defaults to None.
191+
dataset: Defaults to None. Will be split by self.config.general['test_size'] into train and test datasets.
190192
meta_requirements: Defaults to None
191193
trainer_callbacks (List[TrainerCallback]):
192194
A list of trainer callbacks for collecting metrics during the training at the client side. The
193195
transformers Trainer object will be passed in when each callback is called.
194-
196+
train_json_path (str): Defaults to None. If provided, will be used as the training dataset json_path to load from
197+
test_json_path (str): Defaults to None. If provided, will be used as the test dataset json_path to load from
195198
Returns:
196199
Tuple: The dataframe, examples, and the dataset
197200
"""
198201

199-
if dataset is None and json_path is not None:
202+
if dataset is None:
200203
# Load the medcattrainer export
201-
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
204+
if json_path is not None:
205+
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
202206
meta_requirements=meta_requirements, file_name='data_eval.json')
203-
# Load dataset
207+
elif test_json_path is not None and train_json_path is not None:
208+
train_json_path = self._prepare_dataset(train_json_path, ignore_extra_labels=ignore_extra_labels,
209+
meta_requirements=meta_requirements, file_name='data_train.json')
210+
test_json_path = self._prepare_dataset(test_json_path, ignore_extra_labels=ignore_extra_labels,
211+
meta_requirements=meta_requirements, file_name='data_test.json')
204212

205213
# NOTE: The following is for backwards compatibility
206214
# in datasets==2.20.0 `trust_remote_code=True` must be explicitly
@@ -212,13 +220,21 @@ def train(self,
212220
ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True)
213221
else:
214222
ds_load_dataset = datasets.load_dataset
215-
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
216-
data_files={'train': json_path}, # type: ignore
217-
split='train',
218-
cache_dir='/tmp/')
219-
# We split before encoding so the split is document level, as encoding
220-
#does the document splitting into max_seq_len
221-
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
223+
224+
if json_path:
225+
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
226+
data_files={'train': json_path}, # type: ignore
227+
split='train',
228+
cache_dir='/tmp/')
229+
# We split before encoding so the split is document level, as encoding
230+
# does the document splitting into max_seq_len
231+
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
232+
elif train_json_path and test_json_path:
233+
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
234+
data_files={'train': train_json_path, 'test': test_json_path}, # type: ignore
235+
cache_dir='/tmp/')
236+
else:
237+
raise ValueError("Either json_path or train_json_path and test_json_path must be provided when no dataset is provided")
222238

223239
# Update labelmap in case the current dataset has more labels than what we had before
224240
self.tokenizer.calculate_label_map(dataset['train'])
@@ -520,3 +536,27 @@ def __call__(self, doc: Doc) -> Doc:
520536
def func_has_kwarg(func: Callable, keyword: str):
521537
sig = inspect.signature(func)
522538
return keyword in sig.parameters
539+
540+
541+
if __name__ == "__main__":
542+
import json
543+
from copy import copy
544+
from medcat.utils.ner.deid import DeIdModel
545+
546+
mct_export = json.load(open('/Users/k1897038/Downloads/MedCAT_Export_With_Text_2025-03-28_18_49_30.json'))
547+
548+
train_set = {'projects': [copy(mct_export['projects'][0])]}
549+
train_set['projects'][0]['documents'] = mct_export['projects'][0]['documents'][0:8]
550+
551+
test_set = {'projects': [copy(mct_export['projects'][0])]}
552+
test_set['projects'][0]['documents'] = mct_export['projects'][0]['documents'][8:]
553+
554+
json.dump(train_set, open('train_set.json', 'w'))
555+
json.dump(test_set, open('test_set.json', 'w'))
556+
557+
train_json_path = 'train_set.json'
558+
test_json_path = 'test_set.json'
559+
deid_model_path = '/Users/k1897038/Documents/cogstack_docs/medcat_models/medcat_deid_model_691c3f6a6e5400e7.zip'
560+
deid_cat = DeIdModel.load_model_pack(deid_model_path)
561+
deid_cat.train(train_json_path=train_json_path, test_json_path=test_json_path)
562+

medcat-v1/medcat/utils/ner/deid.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,16 @@ class DeIdModel(NerModel):
6262
def __init__(self, cat: CAT) -> None:
6363
self.cat = cat
6464

65-
def train(self, json_path: Union[str, list, None],
65+
def train(self, json_path: Union[str, list, None]=None,
66+
train_json_path: Union[str, list, None]=None,
67+
test_json_path: Union[str, list, None]=None,
6668
*args, **kwargs) -> Tuple[Any, Any, Any]:
67-
return super().train(json_path, *args, train_nr=0, **kwargs) # type: ignore
69+
assert not all([json_path, train_json_path, test_json_path]), \
70+
"Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
71+
72+
return super().train(json_path=json_path,
73+
train_json_path=train_json_path,
74+
test_json_path=test_json_path, *args, **kwargs) # type: ignore
6875

6976
def deid_text(self, text: str, redact: bool = False) -> str:
7077
"""Deidentify text and potentially redact information.

0 commit comments

Comments
 (0)