@@ -176,7 +176,9 @@ def train(self,
176176 ignore_extra_labels = False ,
177177 dataset = None ,
178178 meta_requirements = None ,
179- trainer_callbacks : Optional [List [TrainerCallback ]]= None ) -> Tuple :
179+ trainer_callbacks : Optional [List [TrainerCallback ]]= None ,
180+ train_json_path : str = None ,
181+ test_json_path : str = None ) -> Tuple :
180182 """Train or continue training a model given a json_path containing a MedCATtrainer export. It will
181183 continue training if an existing model is loaded or start new training if the model is blank/new.
182184
@@ -186,21 +188,27 @@ def train(self,
186188 ignore_extra_labels:
187189 Makes only sense when an existing deid model was loaded and from the new data we want to ignore
188190 labels that did not exist in the old model.
189- dataset: Defaults to None.
191+ dataset: Defaults to None. Will be split by self.config.general['test_size'] into train and test datasets.
190192 meta_requirements: Defaults to None
191193 trainer_callbacks (List[TrainerCallback]):
192194 A list of trainer callbacks for collecting metrics during the training at the client side. The
193195 transformers Trainer object will be passed in when each callback is called.
194-
196+ train_json_path (str): Defaults to None. If provided, will be used as the training dataset json_path to load from
197+ test_json_path (str): Defaults to None. If provided, will be used as the test dataset json_path to load from
195198 Returns:
196199 Tuple: The dataframe, examples, and the dataset
197200 """
198201
199- if dataset is None and json_path is not None :
202+ if dataset is None :
200203 # Load the medcattrainer export
201- json_path = self ._prepare_dataset (json_path , ignore_extra_labels = ignore_extra_labels ,
204+ if json_path is not None :
205+ json_path = self ._prepare_dataset (json_path , ignore_extra_labels = ignore_extra_labels ,
202206 meta_requirements = meta_requirements , file_name = 'data_eval.json' )
203- # Load dataset
207+ elif test_json_path is not None and train_json_path is not None :
208+ train_json_path = self ._prepare_dataset (train_json_path , ignore_extra_labels = ignore_extra_labels ,
209+ meta_requirements = meta_requirements , file_name = 'data_train.json' )
210+ test_json_path = self ._prepare_dataset (test_json_path , ignore_extra_labels = ignore_extra_labels ,
211+ meta_requirements = meta_requirements , file_name = 'data_test.json' )
204212
205213 # NOTE: The following is for backwards compatibility
206214 # in datasets==2.20.0 `trust_remote_code=True` must be explicitly
@@ -212,13 +220,21 @@ def train(self,
212220 ds_load_dataset = partial (datasets .load_dataset , trust_remote_code = True )
213221 else :
214222 ds_load_dataset = datasets .load_dataset
215- dataset = ds_load_dataset (os .path .abspath (transformers_ner .__file__ ),
216- data_files = {'train' : json_path }, # type: ignore
217- split = 'train' ,
218- cache_dir = '/tmp/' )
219- # We split before encoding so the split is document level, as encoding
220- #does the document splitting into max_seq_len
221- dataset = dataset .train_test_split (test_size = self .config .general ['test_size' ]) # type: ignore
223+
224+ if json_path :
225+ dataset = ds_load_dataset (os .path .abspath (transformers_ner .__file__ ),
226+ data_files = {'train' : json_path }, # type: ignore
227+ split = 'train' ,
228+ cache_dir = '/tmp/' )
229+ # We split before encoding so the split is document level, as encoding
230+ # does the document splitting into max_seq_len
231+ dataset = dataset .train_test_split (test_size = self .config .general ['test_size' ]) # type: ignore
232+ elif train_json_path and test_json_path :
233+ dataset = ds_load_dataset (os .path .abspath (transformers_ner .__file__ ),
234+ data_files = {'train' : train_json_path , 'test' : test_json_path }, # type: ignore
235+ cache_dir = '/tmp/' )
236+ else :
237+ raise ValueError ("Either json_path or train_json_path and test_json_path must be provided when no dataset is provided" )
222238
223239 # Update labelmap in case the current dataset has more labels than what we had before
224240 self .tokenizer .calculate_label_map (dataset ['train' ])
@@ -520,3 +536,27 @@ def __call__(self, doc: Doc) -> Doc:
def func_has_kwarg(func: Callable, keyword: str):
    """Report whether ``func`` declares a parameter named ``keyword``.

    Only names that appear in the callable's signature are matched, so a
    catch-all ``**kwargs`` parameter does not count as declaring arbitrary
    keywords (it is matched only by its own name).

    Args:
        func (Callable): The callable whose signature is inspected.
        keyword (str): The parameter name to look for.

    Returns:
        bool: True if ``keyword`` is one of the signature's parameter names,
            False otherwise.
    """
    declared_params = inspect.signature(func).parameters
    return keyword in declared_params
539+
540+
if __name__ == "__main__":
    # Manual smoke test for training with separate train/test json files:
    # split the first project of a MedCATtrainer export into a train part and
    # a test part, write both to disk, and train a loaded DeId model on them.
    import argparse
    import json
    from copy import copy
    from medcat.utils.ner.deid import DeIdModel

    parser = argparse.ArgumentParser(
        description="Split a MedCATtrainer export into train/test files "
                    "and train a DeId model pack on them.")
    # Defaults reproduce the previously hard-coded values, so running the
    # module without arguments behaves as before.
    parser.add_argument(
        '--export-path',
        default='/Users/k1897038/Downloads/MedCAT_Export_With_Text_2025-03-28_18_49_30.json',
        help="Path to the MedCATtrainer export json.")
    parser.add_argument(
        '--model-pack-path',
        default='/Users/k1897038/Documents/cogstack_docs/medcat_models/medcat_deid_model_691c3f6a6e5400e7.zip',
        help="Path to the DeId model pack to load.")
    parser.add_argument(
        '--num-train-docs', type=int, default=8,
        help="Number of leading documents that go into the train split; "
             "the remaining documents form the test split.")
    args = parser.parse_args()

    # Context managers guarantee the handles are closed (and, for the written
    # files, flushed) before training reads them back from disk.
    with open(args.export_path, encoding='utf-8') as export_file:
        mct_export = json.load(export_file)

    first_project = mct_export['projects'][0]
    all_documents = first_project['documents']

    # A shallow copy per split is enough: only the top-level 'documents' key
    # is rebound, so the two splits and the loaded export do not interfere.
    train_set = {'projects': [copy(first_project)]}
    train_set['projects'][0]['documents'] = all_documents[0:args.num_train_docs]

    test_set = {'projects': [copy(first_project)]}
    test_set['projects'][0]['documents'] = all_documents[args.num_train_docs:]

    train_json_path = 'train_set.json'
    test_json_path = 'test_set.json'

    with open(train_json_path, 'w', encoding='utf-8') as train_file:
        json.dump(train_set, train_file)
    with open(test_json_path, 'w', encoding='utf-8') as test_file:
        json.dump(test_set, test_file)

    deid_cat = DeIdModel.load_model_pack(args.model_pack_path)
    deid_cat.train(train_json_path=train_json_path, test_json_path=test_json_path)
562+
0 commit comments