Commit df0080c

Author: Tom Searle (committed)
CU-8698jzjj3: flake8 fixes
1 parent 65c84dc commit df0080c

File tree

3 files changed: +79 −75 lines changed


medcat-v1/medcat/ner/transformers_ner.py

Lines changed: 6 additions & 7 deletions
@@ -213,7 +213,7 @@ def train(self,
 train_json_path = self._prepare_dataset(train_json_path, ignore_extra_labels=ignore_extra_labels,
 meta_requirements=meta_requirements, file_name='data_train.json')
 test_json_path = self._prepare_dataset(test_json_path, ignore_extra_labels=ignore_extra_labels,
-meta_requirements=meta_requirements, file_name='data_test.json')
+meta_requirements=meta_requirements, file_name='data_test.json')

 # NOTE: The following is for backwards comppatibility
 # in datasets==2.20.0 `trust_remote_code=True` must be explicitly
@@ -225,7 +225,7 @@ def train(self,
 ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True)
 else:
 ds_load_dataset = datasets.load_dataset
-
+
 if json_path:
 dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
 data_files={'train': json_path}, # type: ignore
@@ -235,8 +235,8 @@ def train(self,
 # does the document splitting into max_seq_len
 dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
 elif train_json_path and test_json_path:
-dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
-data_files={'train': train_json_path, 'test': test_json_path}, # type: ignore
+dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
+data_files={'train': train_json_path, 'test': test_json_path}, # type: ignore
 cache_dir='/tmp/')
 else:
 raise ValueError("Either json_path or train_json_path and test_json_path must be provided when no dataset is provided")
@@ -248,8 +248,8 @@ def train(self,
 if self.model.num_labels != len(self.tokenizer.label_map):
 logger.warning("The dataset contains labels we've not seen before, model is being reinitialized")
 logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map)))
-self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
-num_labels=len(self.tokenizer.label_map),
+self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
+num_labels=len(self.tokenizer.label_map),
 ignore_mismatched_sizes=True)
 self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()}

@@ -290,7 +290,6 @@ def train(self,
 # NOTE: this shouldn't really happen, but we'll do this for type safety
 raise ValueError("Output path should not be None!")
 self.save(save_dir_path=os.path.join(output_dir, 'final_model'))
-
 # Run an eval step and return metrics
 p = trainer.predict(encoded_dataset['test']) # type: ignore
 df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer, dataset=encoded_dataset['test'])
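For context, the branch being re-indented here chooses between two ways of supplying training data. A minimal sketch of both call patterns, assuming `ner` is an already-constructed instance of the NER component defined in this module; the file names and the `result` variable are illustrative placeholders, not values from this commit:

```python
# Hedged sketch: `ner` is assumed to be the object whose train() method is
# patched above; the JSON paths are placeholder MedCATTrainer exports.

# Option 1: a single export, split internally via config.general['test_size'].
result = ner.train(json_path='annotations_export.json')

# Option 2: explicit train/test exports loaded as separate dataset splits.
result = ner.train(train_json_path='data_train.json',
                   test_json_path='data_test.json')

# Supplying neither raises the ValueError shown in the diff:
# "Either json_path or train_json_path and test_json_path must be provided
#  when no dataset is provided"
```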

medcat-v1/medcat/utils/ner/deid.py

Lines changed: 57 additions & 51 deletions
@@ -63,7 +63,7 @@ class DeIdModel(NerModel):
 def __init__(self, cat: CAT) -> None:
 self.cat = cat

-def train(self, json_path: Union[str, list, None]=None,
+def train(self, json_path: Union[str, list, None] = None,
 *args, **kwargs) -> Tuple[Any, Any, Any]:
 assert not all([json_path, kwargs.get('train_json_path'), kwargs.get('test_json_path')]), \
 "Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
@@ -149,7 +149,8 @@ def deid_multi_texts(self,
 return out

 @classmethod
-def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
+def load_model_pack(cls, model_pack_path: str,
+config: Optional[Dict] = None) -> 'DeIdModel':
 """Load DeId model from model pack.

 The method first loads the CAT instance.
@@ -167,7 +168,7 @@ def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) ->
 Returns:
 DeIdModel: The resulting DeI model.
 """
-ner_model = NerModel.load_model_pack(model_pack_path,config=config)
+ner_model = NerModel.load_model_pack(model_pack_path, config=config)
 cat = ner_model.cat
 if not cls._is_deid_model(cat):
 raise ValueError(
@@ -190,25 +191,25 @@ def _get_reason_not_deid(cls, cat: CAT) -> str:


 def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
-"""
-Match a set of rules - pat / cui combos as post processing labels, uses
-a cat DeID model forp pretty name mapping
+"""Match a set of rules - pat / cui combos as post processing labels.
+
+Uses a cat DeID model for pretty name mapping.

 Examples:
->>> rules = [
-('(123) 456-7890', '134'),
-('1234567890', '134'),
-('123.456.7890', '134'),
-('1234567890', '134'),
-('1234567890', '134'),
-]
->>> texts = [
-'My phone number is (123) 456-7890',
-'My phone number is 1234567890',
-'My phone number is 123.456.7890',
-'My phone number is 1234567890',
-]
->>> matches = match_rules(rules, texts, cat)
+>>> rules = [
+('(123) 456-7890', '134'),
+('1234567890', '134'),
+('123.456.7890', '134'),
+('1234567890', '134'),
+('1234567890', '134'),
+]
+>>> texts = [
+'My phone number is (123) 456-7890',
+'My phone number is 1234567890',
+'My phone number is 123.456.7890',
+'My phone number is 1234567890',
+]
+>>> matches = match_rules(rules, texts, cat)
 """
 # Iterate through each text and pattern combination
 rule_matches_per_text = []
@@ -217,7 +218,6 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
 for pattern, concept in rules:
 # Find all matches of current pattern in current text
 text_matches = re.finditer(pattern, text, flags=re.M)
-
 # Add each match with its pattern and text info
 for match in text_matches:
 matches_in_text.append({
@@ -233,31 +233,40 @@ def match_rules(rules: List[Tuple[str, str]], texts: List[str], cat: CAT):
 return rule_matches_per_text


-def merge_preds(model_preds_by_text: List[List[Dict]], rule_matches_per_text: List[List[Dict]], accept_preds=True):
-"""
-Merge predictions from rule based and deID model predictions for further evaluation
+def merge_preds(model_preds_by_text: List[List[Dict]],
+rule_matches_per_text: List[List[Dict]],
+accept_preds: bool = True):
+"""Merge predictions from rule based and deID model predictions.

-Args:
-model_preds_by_text (List[Dict]): list of predictions from `cat.get_entities()`, then `[list(m['entities'].values()) for m in model_preds]`
-rule_matches_by_text (List[Dict]): list of predictions from output of running `match_rules`
-accept_preds (bool): uses the predicted label from the model, model_preds_by_text, over the rule matches if they overlap. Defaults to using model preds over rules.
+Args:
+model_preds_by_text (List[Dict]): list of predictions from
+`cat.get_entities()`, then `[list(m['entities'].values()) for m in model_preds]`
+rule_matches_by_text (List[Dict]): list of predictions from output of
+running `match_rules`
+accept_preds (bool): uses the predicted label from the model,
+model_preds_by_text, over the rule matches if they overlap.
+Defaults to using model preds over rules.

 Examples:
->>> # a list of lists of predictions from `cat.get_entities()`
->>> model_preds_by_text = [
-[
-{'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, 'pretty_name': 'Phone Number'},
-{'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, 'pretty_name': 'Phone Number'}
+>>> # a list of lists of predictions from `cat.get_entities()`
+>>> model_preds_by_text = [
+[
+{'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
+'pretty_name': 'Phone Number'},
+{'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
+'pretty_name': 'Phone Number'}
+]
 ]
-]
->>> # a list of lists of predictions from `match_rules`
->>> rule_matches_by_text = [
-[
-{'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, 'pretty_name': 'Phone Number'},
-{'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, 'pretty_name': 'Phone Number'}
+>>> # a list of lists of predictions from `match_rules`
+>>> rule_matches_by_text = [
+[
+{'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
+'pretty_name': 'Phone Number'},
+{'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
+'pretty_name': 'Phone Number'}
+]
 ]
-]
->>> merged_preds = merge_preds(model_preds_by_text, rule_matches_by_text)
+>>> merged_preds = merge_preds(model_preds_by_text, rule_matches_by_text)
 """
 all_preds = []
 if accept_preds:
@@ -266,28 +275,25 @@ def merge_preds(model_preds_by_text: List[List[Dict]], rule_matches_per_text: Li
 else:
 labels1 = rule_matches_per_text
 labels2 = model_preds_by_text
-
 for matches_text1, matches_text2 in zip(labels1, labels2):
 # Function to check if two spans overlap
 def has_overlap(span1, span2):
-return not (span1['end'] <= span2['start'] or span2['end'] <= span1['start'])
-
+return not (span1['end'] <= span2['start'] or
+span2['end'] <= span1['start'])
+
 # Mark model predictions that overlap with rule matches
-
 to_remove = set()
 for text_match1 in matches_text1:
 for i, text_match2 in enumerate(matches_text2):
 if has_overlap(text_match1, text_match2):
 to_remove.add(i)
-
+
 # Keep only non-overlapping model predictions
-matches_text2 = [text_match for i, text_match in enumerate(matches_text2) if i not in to_remove]
-
+matches_text2 = [text_match for i, text_match in
+enumerate(matches_text2) if i not in to_remove]
+
 # merge preds and sort on start
 merged_preds = matches_text1 + matches_text2
 merged_preds.sort(key=lambda x: x['start'])
 all_preds.append(merged_preds)
 return all_preds
-
-
-
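Taken together, the functions touched in this file form a small rule-plus-model de-identification pipeline. A minimal usage sketch assembled from the docstring examples above, assuming the module path `medcat.utils.ner.deid`; the model-pack path and the phone-number regex rule are illustrative placeholders, not values from this commit:

```python
from medcat.utils.ner.deid import DeIdModel, match_rules, merge_preds

# Placeholder path to a trained DeID model pack (assumption for illustration).
deid_model = DeIdModel.load_model_pack('<path-to-deid-model-pack>')
cat = deid_model.cat

rules = [(r'\(\d{3}\) \d{3}-\d{4}', '134')]   # (regex pattern, CUI) pairs
texts = ['My phone number is (123) 456-7890']

# Rule-based matches; the CAT instance supplies pretty names for the CUIs.
rule_matches_per_text = match_rules(rules, texts, cat)

# Model predictions, reshaped as described in the merge_preds docstring.
model_preds = [cat.get_entities(t) for t in texts]
model_preds_by_text = [list(m['entities'].values()) for m in model_preds]

# Merge the two, preferring model predictions on overlapping spans.
merged = merge_preds(model_preds_by_text, rule_matches_per_text, accept_preds=True)
```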

medcat-v1/medcat/utils/ner/metrics.py

Lines changed: 16 additions & 17 deletions
@@ -116,7 +116,7 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer
 for key in _cr:
 cui = ilabel_map[key]
 p_merged = tp_all / (tp_all + fp_all) if (tp_all + fp_all) > 0 else 0
-data.append([cui, tokenizer.cui2name.get(cui, cui), _cr[key]['precision'],
+data.append([cui, tokenizer.cui2name.get(cui, cui), _cr[key]['precision'],
 _cr[key]['recall'], _cr[key]['f1-score'], _cr[key]['support'], _cr[key]['r_merged'], p_merged])

 df = pd.DataFrame(data[1:], columns=data[0])
@@ -133,7 +133,7 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer

 def _anno_within_pred_list(label: Dict, preds: List[Dict]) -> bool:
 """
-Check if a label is within a list of predictions,
+Check if a label is within a list of predictions,

 Args:
 label (Dict): an annotation likely from a MedCATTrainer project
@@ -147,9 +147,9 @@ def _anno_within_pred_list(label: Dict, preds: List[Dict]) -> bool:

 def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[List[Dict]], texts: List[str], deid_cdb: CDB):
 """
-Evaluate predictions against sets of collected labels as collected and output from a MedCATTrainer project.
+Evaluate predictions against sets of collected labels as collected and output from a MedCATTrainer project.
 Counts predictions as correct if the prediction fully encloses the label.
-
+
 Args:
 true_annotations (List[List[Dict]]): Ground truth predictions by text
 all_preds (List[List[Dict]]): Model predictions by text
@@ -165,37 +165,37 @@ def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[Lis
 per_cui_anno_counts = {}
 per_cui_annos_missed = defaultdict(list)
 uniq_labels = set([p['cui'] for ap in true_annotations for p in ap])
-
+
 for cui in uniq_labels:
 # annos in test set
 anno_count = sum([len([p for p in cui_annos if p['cui'] == cui]) for cui_annos in true_annotations])
 pred_counts = sum([len([p for p in d if p['cui'] == cui]) for d in all_preds])
-
+
 # print(anno_count)
 # print(pred_counts)
-
+
 # print(f'pred_count: {pred_counts}, anno_count:{anno_count}')
 per_cui_anno_counts[cui] = anno_count
-
+
 doc_annos_left, preds_left, doc_annos_left_any_cui = [], [], []
-
+
 for doc_preds, doc_labels, text in zip(all_preds, true_annotations, texts):
 # num of annos that are not found - recall
-cui_labels = [l for l in doc_labels if l['cui'] == cui]
-cui_doc_preds = [p for p in doc_preds if p['cui'] == cui]
-
+cui_labels = [label for label in doc_labels if label['cui'] == cui]
+cui_doc_preds = [pred for pred in doc_preds if pred['cui'] == cui]
+
 labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, cui_doc_preds)]
 doc_annos_left.append(len(labels_not_found))
-
+
 # num of annos that are not found across any cui prediction - recall_merged
 any_labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, doc_preds)]
 doc_annos_left_any_cui.append(len(any_labels_not_found))

 per_cui_annos_missed[cui].append(any_labels_not_found)
-
+
 # num of preds that are incorrect - precision
 preds_left.append(len([label for label in cui_doc_preds if not _anno_within_pred_list(label, cui_labels)]))
-
+
 if anno_count != 0 and pred_counts != 0:
 per_cui_recall[cui] = (anno_count - sum(doc_annos_left)) / anno_count
 per_cui_recall_merged[cui] = (anno_count - sum(doc_annos_left_any_cui)) / anno_count
@@ -211,6 +211,5 @@ def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[Lis
 'recall': per_cui_recall.values(),
 'precision': per_cui_prec.values(),
 'label_count': per_cui_anno_counts.values()}, index=[deid_cdb.cui2preferred_name[k] for k in per_cui_recall_merged])
-
-return res_df, per_cui_annos_missed

+return res_df, per_cui_annos_missed
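For completeness, a sketch of how the evaluation entry point cleaned up above is typically driven. The toy annotations, the span offsets, and the way `deid_cdb` is obtained (here via a previously loaded model pack's CAT instance) are assumptions for illustration, not part of this commit:

```python
from medcat.utils.ner.metrics import evaluate_predictions

# One inner list per text, shaped as in the docstring Args above (toy values).
texts = ['My phone number is (123) 456-7890']
true_annotations = [[{'cui': '134', 'start': 19, 'end': 33}]]
all_preds = [[{'cui': '134', 'start': 19, 'end': 33}]]   # e.g. merge_preds output

# The concept database of the DeID model; assumed here to come from the
# DeIdModel loaded in the earlier sketch, used for preferred names in the index.
deid_cdb = deid_model.cat.cdb

res_df, per_cui_annos_missed = evaluate_predictions(
    true_annotations, all_preds, texts, deid_cdb)

print(res_df[['recall', 'precision', 'label_count']])
```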
