Skip to content

Commit 3713aa2

Browse files
author
Tom Searle
committed
CU-8698jzjj3: datasets splits fix and extra test for extra named arg
1 parent ea7538b commit 3713aa2

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

medcat-v1/medcat/datasets/transformers_ner.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,28 @@ def _info(self):
6464

6565
def _split_generators(self, dl_manager): # noqa
6666
"""Returns SplitGenerators.""" # noqa
67-
return [
67+
splits = [
6868
datasets.SplitGenerator(
6969
name=datasets.Split.TRAIN,
7070
gen_kwargs={
7171
"filepaths": self.config.data_files['train'],
7272
},
7373
),
74-
datasets.SplitGenerator(
75-
name=datasets.Split.TEST,
76-
gen_kwargs={
77-
"filepaths": self.config.data_files['test'],
78-
},
79-
),
8074
]
8175

76+
# Only add test split if test data files are provided
77+
if 'test' in self.config.data_files:
78+
splits.append(
79+
datasets.SplitGenerator(
80+
name=datasets.Split.TEST,
81+
gen_kwargs={
82+
"filepaths": self.config.data_files['test'],
83+
},
84+
)
85+
)
86+
87+
return splits
88+
8289
def _generate_examples(self, filepaths): # noqa
8390
cnt = 0
8491
for filepath in filepaths:

medcat-v1/tests/ner/test_transformers_ner.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ def test_pipe(self):
3434

3535
def test_train(self):
3636
tracker = unittest.mock.Mock()
37+
3738
class _DummyCallback(TrainerCallback):
3839
def __init__(self, trainer) -> None:
3940
self._trainer = trainer
41+
4042
def on_epoch_end(self, *args, **kwargs) -> None:
4143
tracker.call()
4244

@@ -49,13 +51,32 @@ def on_epoch_end(self, *args, **kwargs) -> None:
4951
assert dataset["test"].num_rows == 12
5052
self.assertEqual(tracker.call.call_count, 2)
5153

54+
def test_train_with_test_file(self):
55+
tracker = unittest.mock.Mock()
56+
57+
class _DummyCallback(TrainerCallback):
58+
def __init__(self, trainer) -> None:
59+
self._trainer = trainer
60+
61+
def on_epoch_end(self, *args, **kwargs) -> None:
62+
tracker.call()
63+
64+
train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json")
65+
test_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_test_data.json")
66+
self.undertest.training_arguments.num_train_epochs = 1
67+
df, examples, dataset = self.undertest.train(train_json_path=train_data, test_json_path=test_data, trainer_callbacks=[_DummyCallback])
68+
assert "fp" in examples
69+
assert "fn" in examples
70+
assert dataset["train"].num_rows == 60
71+
self.assertEqual(tracker.call.call_count, 1)
72+
5273
def test_expand_model_with_concepts(self):
5374
original_num_labels = self.undertest.model.num_labels
54-
original_out_features = self.undertest.model.classifier.out_features
75+
original_out_features = self.undertest.model.classifier.out_features
5576
original_label_map_size = len(self.undertest.tokenizer.label_map)
5677
cui2preferred_name = {
57-
"concept_1" : "Preferred Name 1",
58-
"concept_2" : "Preferred Name 2",
78+
"concept_1": "Preferred Name 1",
79+
"concept_2": "Preferred Name 2",
5980
}
6081

6182
self.undertest.expand_model_with_concepts(cui2preferred_name)

0 commit comments

Comments
 (0)