From 0e842b40b648d46f796274f86d7359eb70f27e76 Mon Sep 17 00:00:00 2001 From: pidoux7 Date: Wed, 15 May 2024 11:40:53 +0200 Subject: [PATCH 01/18] Add relations --- edsnlp/data/converters.py | 58 +++- edsnlp/pipes/__init__.py | 1 + edsnlp/pipes/misc/relations/__init__.py | 1 + edsnlp/pipes/misc/relations/factory.py | 17 ++ edsnlp/pipes/misc/relations/patterns.py | 17 ++ edsnlp/pipes/misc/relations/relations.py | 366 +++++++++++++++++++++++ pyproject.toml | 1 + tests/pipelines/misc/test_relations.py | 60 ++++ tests/resources/relations/relations.json | 40 +++ tests/resources/relations/text.ann | 82 +++++ tests/resources/relations/text.txt | 3 + 11 files changed, 642 insertions(+), 4 deletions(-) create mode 100644 edsnlp/pipes/misc/relations/__init__.py create mode 100644 edsnlp/pipes/misc/relations/factory.py create mode 100644 edsnlp/pipes/misc/relations/patterns.py create mode 100644 edsnlp/pipes/misc/relations/relations.py create mode 100644 tests/pipelines/misc/test_relations.py create mode 100644 tests/resources/relations/relations.json create mode 100644 tests/resources/relations/text.ann create mode 100644 tests/resources/relations/text.txt diff --git a/edsnlp/data/converters.py b/edsnlp/data/converters.py index c94cf5da1..65b158318 100644 --- a/edsnlp/data/converters.py +++ b/edsnlp/data/converters.py @@ -66,7 +66,7 @@ def validate_kwargs(converter, kwargs): return {**(d.pop(vd.v_kwargs_name, None) or {}), **d} -class SequenceStr: +class SequenceStr: @classmethod def __get_validators__(cls): yield cls.validate @@ -191,6 +191,9 @@ class StandoffDict2DocConverter: span_attributes : Optional[AttributesMappingArg] Mapping from BRAT attributes to Span extensions (can be a list too). By default, all attributes are imported as Span extensions with the same name. + span_rel : Optional[AttributesMappingArg] + Mapping from BRAT relations to Span extensions (can be a list too). + By default, all relations are imported as Span extensions with the name rel. keep_raw_attribute_values : bool Whether to keep the raw attribute values (as strings) or to convert them to Python objects (e.g. booleans).
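For orientation, the `__call__` changes below consume standoff dicts carrying entities and relations. A minimal sketch of that input shape, using only the fields the new code actually reads (`entity_id`, `label`, `fragments`, `attributes`, `notes`, `relation_label`, `from_entity_id`, `to_entity_id`); ids, offsets and values are illustrative:

```python
# Hypothetical input for StandoffDict2DocConverter; field names follow the
# accesses made in __call__ below, values are made up for illustration.
standoff_doc = {
    "doc_id": "doc-0",
    "text": "Amlodipine 5 mg une fois par jour",
    "entities": [
        {
            "entity_id": "T1",
            "label": "Chemical_and_drugs",
            "fragments": [{"begin": 0, "end": 10}],  # "Amlodipine"
            "attributes": {},
            "notes": [],
        },
        {
            "entity_id": "T2",
            "label": "Chemical_and_drugs",
            "fragments": [{"begin": 11, "end": 15}],  # "5 mg"
            "attributes": {"Tech": "strength"},
            "notes": [],
        },
    ],
    "relations": [
        {
            # the span matching T1 gets {"type": "Depend", "target": <T2 span>};
            # the span matching T2 gets the mirrored "inv_Depend" relation
            "relation_label": "Depend",
            "from_entity_id": "T1",
            "to_entity_id": "T2",
        }
    ],
}
```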
@@ -214,6 +217,7 @@ def __init__( tokenizer: Optional[Tokenizer] = None, span_setter: SpanSetterArg = {"ents": True, "*": True}, span_attributes: Optional[AttributesMappingArg] = None, + span_rel: Optional[AttributesMappingArg] = None, # to be decided whether we keep this keep_raw_attribute_values: bool = False, bool_attributes: SequenceStr = [], default_attributes: AttributesMappingArg = {}, @@ -223,6 +227,7 @@ def __init__( self.tokenizer = tokenizer or (nlp.tokenizer if nlp is not None else None) self.span_setter = span_setter self.span_attributes = span_attributes # type: ignore + self.span_rel = span_rel # to be decided whether we keep this self.keep_raw_attribute_values = keep_raw_attribute_values self.default_attributes = default_attributes self.notes_as_span_attribute = notes_as_span_attribute @@ -244,12 +249,17 @@ def __call__(self, obj): if not Span.has_extension(dst): Span.set_extension(dst, default=None) + ############## Modification for relations ############### + dict_entities={} ## dictionary to store the entities for ent in obj.get("entities") or (): + begin = min(f["begin"] for f in ent["fragments"]) # start of the entity + end = max(f["end"] for f in ent["fragments"]) # end of the entity + dict_entities[ent['entity_id']] = ent['label'] + ';' + str(begin) + ';' + str(end) # store the entities fragments = ( [ { - "begin": min(f["begin"] for f in ent["fragments"]), - "end": max(f["end"] for f in ent["fragments"]), + "begin": begin, + "end": end, } ] if not self.split_fragments @@ -260,7 +270,7 @@ def __call__(self, obj): fragment["begin"], fragment["end"], label=ent["label"], - alignment_mode="expand", + alignment_mode="expand", # add id ) if self.notes_as_span_attribute and ent["notes"]: ent["attributes"][self.notes_as_span_attribute] = "|".join( @@ -290,6 +300,7 @@ def __call__(self, obj): span._.set(new_name, value) spans.append(span) + set_spans(doc, spans, span_setter=self.span_setter) for attr, value in self.default_attributes.items(): @@ -297,6 +308,44 @@ def __call__(self, obj): if span._.get(attr) is None: span._.set(attr, value) + + ############## Modification for relations ############### + # Add the relations as spans + if self.span_rel is None and not Span.has_extension( + 'rel' + ): + Span.set_extension('rel', default=[]) + + for rel in obj.get("relations") or (): # iterate over relations + for label in doc.spans: # iterate over source labels + for i, spa in enumerate(doc.spans[label]): # iterate over source spans + bo = False + + #relations + if dict_entities[rel["from_entity_id"]].split(';') == [label , str(spa.start_char), str(spa.end_char)]: # if the source entity matches the span + for label2 in doc.spans: # iterate over target labels + for j, spa2 in enumerate(doc.spans[label2]): # iterate over target spans + if dict_entities[rel["to_entity_id"]].split(';') == [label2 , str(spa2.start_char), str(spa2.end_char)]: # if the target entity matches the span + relation = {'type': rel['relation_label'], 'target': doc.spans[label2][j]} # create the relation + doc.spans[label][i]._.rel.append(relation) # add the relation to the span + bo = True + break + if bo == True: + break + bo = False + + # inverse relations + if dict_entities[rel["to_entity_id"]].split(';') == [label , str(spa.start_char), str(spa.end_char)]: + for label2 in doc.spans: + for j, spa2 in enumerate(doc.spans[label2]): + if dict_entities[rel["from_entity_id"]].split(';') == [label2 , str(spa2.start_char), str(spa2.end_char)]: + relation = {'type': 'inv_' + rel['relation_label'], 'target': doc.spans[label2][j]} +
doc.spans[label][i]._.rel.append(relation) + bo=True + break + if bo == True: + break return doc diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index dea40bc5c..33666f6ba 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -22,6 +22,7 @@ from .misc.dates.factory import create_component as dates from .misc.measurements.factory import create_component as measurements from .misc.reason.factory import create_component as reason + from .misc.relations.factory import create_component as relations from .misc.sections.factory import create_component as sections from .misc.tables.factory import create_component as tables from .ner.adicap.factory import create_component as adicap diff --git a/edsnlp/pipes/misc/relations/__init__.py b/edsnlp/pipes/misc/relations/__init__.py new file mode 100644 index 000000000..e5f49bc5a --- /dev/null +++ b/edsnlp/pipes/misc/relations/__init__.py @@ -0,0 +1 @@ +from .relations import RelationsMatcher \ No newline at end of file diff --git a/edsnlp/pipes/misc/relations/factory.py b/edsnlp/pipes/misc/relations/factory.py new file mode 100644 index 000000000..2a26f27ed --- /dev/null +++ b/edsnlp/pipes/misc/relations/factory.py @@ -0,0 +1,17 @@ +from edsnlp.core import registry + +from .relations import RelationsMatcher + +DEFAULT_CONFIG = dict( + scheme=None, + use_sentences=False, + clean_rel=False, + proximity_method = "right", + max_dist=45, +) + +create_component = registry.factory.register( + "eds.relations", + assigns=["doc.spans"], + deprecated=["relations"], +)(RelationsMatcher) \ No newline at end of file diff --git a/edsnlp/pipes/misc/relations/patterns.py b/edsnlp/pipes/misc/relations/patterns.py new file mode 100644 index 000000000..c4cb8fa13 --- /dev/null +++ b/edsnlp/pipes/misc/relations/patterns.py @@ -0,0 +1,17 @@ +scheme = [ + { + "subject": [{"label": "Chemical_and_drugs", "attr": {"Tech": [None]}}], + "object": [ + { + "label": "Temporal", + "attr": {"AttTemp": [None, "Duration", "Date", "Time"]}, + }, + { + "label": "Chemical_and_drugs", + "attr": {"Tech": ["dosage", "route", "strength", "form"]}, + }, + ], + "type": "Depend", + "inv_type": "inv_Depend", + }, + ] \ No newline at end of file diff --git a/edsnlp/pipes/misc/relations/relations.py b/edsnlp/pipes/misc/relations/relations.py new file mode 100644 index 000000000..9e53bbd6a --- /dev/null +++ b/edsnlp/pipes/misc/relations/relations.py @@ -0,0 +1,366 @@ +from typing import Dict, Iterable, List, Union + +from loguru import logger + +from spacy.tokens import Doc, Span +from typing import Any +from numpy.typing import NDArray + +from edsnlp.core import PipelineProtocol +from edsnlp.pipes.misc.relations import patterns +import math as m +import numpy as np +import json + + +class RelationsMatcher: + """ A spaCy EDSNLP pipeline component to find relations between entities based on their proximity.
+ scheme = [ + { + "subject": [{"label": "Chemical_and_drugs", "attr": {"Tech": [None]}}], + "object": [ + { + "label": "Temporal", + "attr": {"AttTemp": [None, "Duration", "Date", "Time"]}, + }, + { + "label": "Chemical_and_drugs", + "attr": {"Tech": ["dosage", "route", "strength", "form"]}, + }, + ], + "type": "Depend", + "inv_type": "inv_Depend", + }, + ] + """ + + def __init__( + self, + nlp: PipelineProtocol, + name: str = "relations", + *, + scheme: Union[Union[Dict, List[Dict]],str] = None, + use_sentences: bool = False, + proximity_method: str = "right", + clean_rel: bool = True, + max_dist: int = 45, + ): + self.nlp = nlp + if not isinstance(name, str): + raise ValueError("name must be a string") + self.name = name + + if scheme is None: + scheme = patterns.scheme + if isinstance(scheme, str): + # open the json file and load it into a variable + if not scheme.endswith(".json"): + raise ValueError("scheme must be a json file") + with open(scheme) as f: + scheme = json.load(f) + if isinstance(scheme, dict): + scheme = [scheme] + self.check_scheme(scheme) + self.scheme = scheme + + if not isinstance(use_sentences, bool): + raise ValueError("use_sentences must be a boolean") + self.use_sentences = use_sentences and ( + "eds.sentences" in nlp.pipe_names or "sentences" in nlp.pipe_names + ) + if use_sentences and not self.use_sentences: + logger.warning( + "You have requested that the pipeline use annotations " + "provided by the `eds.sentences` pipeline, but it was not set. " + "Skipping that step." + ) + + if proximity_method not in ["sym", "start", "end", "middle", "right", "left"]: + raise ValueError( + """proximity_method must be one of 'sym','start', + 'end', 'middle', 'right', 'left'""" + ) + self.proximity_method = proximity_method + + if not isinstance(clean_rel, bool): + raise ValueError("clean_rel must be a boolean") + self.clean_rel = clean_rel + + if not isinstance(max_dist, int): + raise ValueError("max_dist must be an integer") + self.max_dist = max_dist + + self.set_extensions() + + def check_scheme(self, schemes): + for scheme in schemes: + if not isinstance(scheme, dict): + raise ValueError("scheme must be a dictionary") + if "subject" not in scheme: + raise ValueError("scheme must contain a 'subject' key") + if "object" not in scheme: + raise ValueError("scheme must contain an 'object' key") + if "type" not in scheme: + raise ValueError("scheme must contain a 'type' key") + if "inv_type" not in scheme: + raise ValueError("scheme must contain an 'inv_type' key") + if not isinstance(scheme["subject"], list): + raise ValueError("scheme['subject'] must be a list") + if not isinstance(scheme["object"], list): + raise ValueError("scheme['object'] must be a list") + if not isinstance(scheme["type"], str): + raise ValueError("scheme['type'] must be a string") + if not isinstance(scheme["inv_type"], str): + raise ValueError("scheme['inv_type'] must be a string") + for sub in scheme["subject"]: + if not isinstance(sub, dict): + raise ValueError("scheme['subject'] must contain dictionaries") + if "label" not in sub: + raise ValueError("scheme['subject'] must contain a 'label' key") + if not isinstance(sub["label"], str): + raise ValueError("scheme['subject']['label'] must be a string") + if "attr" in sub: + if sub["attr"] is not None and not isinstance(sub["attr"], dict): + raise ValueError("scheme['subject']['attr'] must be a dictionary or None") + for obj in scheme["object"]: + if not isinstance(obj, dict): + raise ValueError("scheme['object'] must contain
dictionaries") + if "label" not in obj: + raise ValueError("scheme['object'] must contain a 'label' key") + if not isinstance(obj["label"], str): + raise ValueError("scheme['object']['label'] must be a string") + if "attr" in obj: + if obj["attr"] is not None and not isinstance(obj["attr"], dict): + raise ValueError("scheme['object']['attr'] must be a dictionary or None") + return True + + @classmethod + def set_extensions(cls) -> None: + """Set the extension rel for the Span object. + """ + if not Span.has_extension("rel"): + Span.set_extension("rel", default=[]) + + def clean_relations(self, doc: Doc) -> Doc: + """Remove the relations from the doc + + Args: + doc (Doc): the doc to be processed + + Returns: + Doc: the doc with the relations removed + """ + for label, spans in doc.spans.items(): + for span in spans: + if span._.rel: + span._.rel = [] + return doc + + def __call__(self, doc: Doc) -> Doc: + """find the relations in the doc based on the proximity of the entities attributes + + Args: + doc (Doc): the doc to be processed + + Returns: + Doc: the doc with the relations added + """ + if self.clean_rel: + doc = self.clean_relations(doc) + + dict_r = self.find_relations(doc) + + for r, rel in enumerate(dict_r): + if len(dict_r[r]["mat_obj"]) > 0 and len(dict_r[r]["mat_sub"]) > 0: + min_distance_indices, distances = self.calculate_min_distances( + dict_r[r]["mat_sub"], dict_r[r]["mat_obj"] + ) + for i, span_obj in enumerate(dict_r[r]["spans_obj"]): + if distances[min_distance_indices[i]][i] <= self.max_dist: + span_sub = dict_r[r]["spans_sub"][min_distance_indices[i]] + if self.use_sentences and not self.sentences(doc, span_obj["span"], span_sub["span"]): + continue + doc.spans[span_obj['label']][span_obj['num_span']]._.rel.append( + {"type": dict_r[r]["inv_type"], "target": doc.spans[span_sub["label"]][span_sub["num_span"]]} + ) + + doc.spans[span_sub['label']][span_sub['num_span']]._.rel.append( + {"type": dict_r[r]["type"], "target": doc.spans[span_obj["label"]][span_obj["num_span"]]} + ) + return doc + + def sentences(self, doc: Doc, span_obj: Span, span_sub: Span) -> bool: + """ Check if span_obj and span_sub are in the same sentence. + + Args: + doc (Doc): EDSNLP Doc object + span_obj (Span): span representing the target + span_sub (Span): span representing the source + + Returns: + bool: True if span_obj and span_sub are in the same sentence, False otherwise. 
+ """ + for sent in doc.sents: + if span_obj.start >= sent.start and span_obj.end <= sent.end and span_sub.start >= sent.start and span_sub.end <= sent.end: + return True + return False + + + def find_relations(self, doc: Doc) -> Dict: + """ + Detect the potential subjects and objects in the document + + Args: + doc (Doc): EDSNLP Doc object + + Returns: + Dict: dict containing the potential subjects and objects + """ + dict_r = {} + for r, relation in enumerate(self.scheme): + dict_r[r] = { + "mat_obj": [], + "spans_obj": [], + "mat_sub": [], + "spans_sub": [], + "type": relation["type"], + "inv_type": relation["inv_type"], + } + # Treatment of objects + for obj in relation["object"]: + label_obj = obj["label"] + attr_obj = obj["attr"] + if label_obj in doc.spans: + if attr_obj is not None: + for num_span_obj, span_obj in enumerate(doc.spans[label_obj]): + if self.filter_spans( + span_obj, *list(attr_obj.items())[0], label_obj + ): + dict_r[r]["mat_obj"].append( + [span_obj.start_char, span_obj.end_char] + ) + dict_r[r]["spans_obj"].append({'label': label_obj, 'num_span': num_span_obj, 'span': span_obj}) + else: + for num_span_obj, span_obj in enumerate(doc.spans[label_obj]): + dict_r[r]["mat_obj"].append( + [span_obj.start_char, span_obj.end_char] + ) + dict_r[r]["spans_obj"].append({'label': label_obj, 'num_span': num_span_obj, 'span': span_obj}) + + # Treatment of subjects + for sub in relation["subject"]: + label_sub = sub["label"] + attr_sub = sub["attr"] + if label_sub in doc.spans: + if attr_sub is not None: + for num_span_sub, span_sub in enumerate(doc.spans[label_sub]): + if self.filter_spans( + span_sub, *list(attr_sub.items())[0], label_sub + ): + dict_r[r]["mat_sub"].append( + [span_sub.start_char, span_sub.end_char] + ) + dict_r[r]["spans_sub"].append({'label': label_sub, 'num_span': num_span_sub, 'span': span_sub}) + else: + for num_span_sub, span_sub in enumerate(doc.spans[label_sub]): + dict_r[r]["mat_sub"].append( + [span_sub.start_char, span_sub.end_char] + ) + dict_r[r]["spans_sub"].append({'label': label_sub, 'num_span': num_span_sub, 'span': span_sub}) + + # Convert lists to numpy arrays for easier manipulation later + for r in dict_r: + dict_r[r]["mat_obj"] = np.array(dict_r[r]["mat_obj"]) + dict_r[r]["mat_sub"] = np.array(dict_r[r]["mat_sub"]) + + return dict_r + + def calculate_min_distances(self, subjects:NDArray[Any], objects:NDArray[Any]) -> NDArray[Any]: + """ calculate the minimum distance between subjects and objects + + Args: + subjects (NDArray[Any]): entities to be used as subjects + objects (NDArray[Any]): entities to be used as objects + + Returns: + NDArray[Any]: the index of the subject that is closest to the object + """ + + subjects_expanded = subjects[:, np.newaxis, :] + objects_expanded = objects[np.newaxis, :, :] + + # calculate the distances between the entities + distance_start_to_end = subjects_expanded[:, :, 0] - objects_expanded[:, :, 1] + distance_end_to_start = objects_expanded[:, :, 0] - subjects_expanded[:, :, 1] + distance_start_to_start = subjects_expanded[:, :, 0] - objects_expanded[:, :, 0] + distance_end_to_end = objects_expanded[:, :, 1] - subjects_expanded[:, :, 1] + distance_middle = ( + subjects_expanded[:, :, 0] + subjects_expanded[:, :, 1] + ) / 2 - (objects_expanded[:, :, 0] + objects_expanded[:, :, 1]) / 2 + + if self.proximity_method == "sym": + distance_middle = np.abs( + np.minimum(distance_start_to_start, distance_end_to_end) + ) + # determine the mask for the left and right side of the entities + mask_left = 
objects_expanded[:, :, 1] <= subjects_expanded[:, :, 0] + mask_right = subjects_expanded[:, :, 1] <= objects_expanded[:, :, 0] + + # Assign the distances based on the mask + distances = np.where(mask_left, np.abs(distance_start_to_end), np.inf) + distances = np.where(mask_right, np.abs(distance_end_to_start), distances) + mask_middle = np.logical_not(np.logical_or(mask_left, mask_right)) + distances = np.where(mask_middle, distance_middle, distances) + distances = np.where(distances == 0, np.inf, distances) + + # find the index of the minimum distance + min_distance_indices = np.argmin(distances, axis=0) + + if self.proximity_method == "right": + distances = np.abs(distance_end_to_start) + distances = np.where(distances == 0, np.inf, distances) + min_distance_indices = np.argmin(distances, axis=0) + + if self.proximity_method == "left": + distances = np.abs(distance_start_to_end) + distances = np.where(distances == 0, np.inf, distances) + min_distance_indices = np.argmin(distances, axis=0) + + if self.proximity_method == "start": + distances = np.abs(distance_start_to_start) + distances = np.where(distances == 0, np.inf, distances) + min_distance_indices = np.argmin(distances, axis=0) + + if self.proximity_method == "end": + distances = np.abs(distance_end_to_end) + distances = np.where(distances == 0, np.inf, distances) + min_distance_indices = np.argmin(distances, axis=0) + + if self.proximity_method == "middle": + distances = distance_middle + distances = np.where(distances == 0, np.inf, distances) + min_distance_indices = np.argmin(distances, axis=0) + + return min_distance_indices, distances + + def filter_spans( + self, span: Span, attr_name: str, attr_values: list, label: str + ) -> bool: + """Filter the spans based on the attribute values + + Args: + span (Span): the span to be filtered + attr_name (str): the name of the attribute + attr_values (list): the values of the attribute + label (str): the label of the span + + Returns: + bool: True if the attribute value is in attr_values, False otherwise + """ + # Get the attribute value or None if it doesn't exist + attr_value = getattr(span._, attr_name, None) + if attr_value in attr_values: + return True + return False + + diff --git a/pyproject.toml b/pyproject.toml index 9b347117a..47a810890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,7 @@ where = ["."] "eds.reason" = "edsnlp.pipes.misc.reason.factory:create_component" "eds.sections" = "edsnlp.pipes.misc.sections.factory:create_component" "eds.tables" = "edsnlp.pipes.misc.tables.factory:create_component" +"eds.relations" = "edsnlp.pipes.misc.relations.factory:create_component" # Deprecated (links to the same factories as above) "SOFA" = "edsnlp.pipes.ner.scores.sofa.factory:create_component" diff --git a/tests/pipelines/misc/test_relations.py b/tests/pipelines/misc/test_relations.py new file mode 100644 index 000000000..8a8437818 --- /dev/null +++ b/tests/pipelines/misc/test_relations.py @@ -0,0 +1,60 @@ +import sys +import os +from pytest import mark +import pytest +from spacy.tokens import Doc, Span + +# Make sure the path to edsnlp comes first in sys.path +sys.path.insert(0, "/home/pidoux/edsnlp") + +# Import the modules after adding the path +import edsnlp +import edsnlp.pipes.core as eds + +@mark.parametrize("use_sentences", [True, False]) +@mark.parametrize("clean_rel", [True, False]) +@mark.parametrize("proximity_method", ["sym", "right", "left", "middle", "start", "end"]) +@mark.parametrize("max_dist", [1, 40, 100]) +def test_relations(use_sentences, clean_rel,
proximity_method, max_dist): + dossier = "../../resources/relations/" + doc_iterator = edsnlp.data.read_standoff(dossier) + corpus = list(doc_iterator) + assert len(corpus) > 0 + for doc in corpus: + assert isinstance(doc, Doc) + for label in doc.spans: + for span in doc.spans[label]: + assert isinstance(span, Span) + assert span.has_extension('rel') + for rel in span._.rel: + assert isinstance(rel['target'], Span) + assert isinstance(rel['type'], str) + assert rel['type'] == 'Depend' or rel['type'] == 'inv_Depend' + + nlp = edsnlp.blank("eds") + nlp.add_pipe("eds.sentences") + nlp.add_pipe( + "eds.relations", + config={ + "scheme": os.path.join(dossier, "relations.json"), + "use_sentences": use_sentences, + "clean_rel": clean_rel, + "proximity_method": proximity_method, + "max_dist": max_dist, + }, + ) + + doc = nlp(corpus[0]) + + for label in doc.spans: + for span in doc.spans[label]: + print(span, span._.rel) + assert isinstance(span, Span) + assert span.has_extension('rel') + for rel in span._.rel: + assert isinstance(rel['target'], Span) + assert isinstance(rel['type'], str) + assert rel['type'] == 'Depend' or rel['type'] == 'inv_Depend' + +if __name__ == "__main__": + pytest.main() diff --git a/tests/resources/relations/relations.json b/tests/resources/relations/relations.json new file mode 100644 index 000000000..a3d77fe13 --- /dev/null +++ b/tests/resources/relations/relations.json @@ -0,0 +1,40 @@ +[ + { + "subject": [ + { + "label": "Chemical_and_drugs", + "attr": { + "Tech": [ + null + ] + } + } + ], + "object": [ + { + "label": "Temporal", + "attr": { + "AttTemp": [ + "Duration", + "Date", + "Frequency", + "Time" + ] + } + }, + { + "label": "Chemical_and_drugs", + "attr": { + "Tech": [ + "dosage", + "route", + "strength", + "form" + ] + } + } + ], + "type": "Depend", + "inv_type": "inv_Depend" + } +] diff --git a/tests/resources/relations/text.ann b/tests/resources/relations/text.ann new file mode 100644 index 000000000..abc346ad2 --- /dev/null +++ b/tests/resources/relations/text.ann @@ -0,0 +1,82 @@ +T1 DISO 24 41 fatigue chronique +A1 Certainty T1 Certain +T2 DISO 49 61 maux de tête +A2 Certainty T2 Certain +T3 Constantes 97 134 tension artérielle est de 145/90 mmHg +T4 Constantes 139 172 glycémie à jeun est de 7.8 mmol/L +T5 Constantes 179 212 saturation en oxygène est de 98 % +T6 Chemical_and_drugs 238 248 Amlodipine +A3 Certainty T6 Certain +A4 Temporality T6 Present +T7 Chemical_and_drugs 263 267 5 mg +A5 Tech T7 strength +T8 Chemical_and_drugs 282 290 comprimé +A6 Tech T8 form +T9 Chemical_and_drugs 300 310 voie orale +A7 Tech T9 route +T10 Chemical_and_drugs 311 328 une fois par jour +A8 Tech T10 dosage +R1 Depend Arg1:T6 Arg2:T7 +R2 Depend Arg1:T6 Arg2:T8 +R3 Depend Arg1:T6 Arg2:T9 +R4 Depend Arg1:T6 Arg2:T10 +T11 Chemical_and_drugs 330 340 Metformine +A9 Certainty T11 Certain +A10 Temporality T11 Present +T12 Chemical_and_drugs 355 380 500 mg deux fois par jour +A11 Tech T12 dosage +T13 Chemical_and_drugs 355 361 500 mg +A12 Tech T13 strength +T14 Temporal 362 380 deux fois par jour +A13 AttTemp T14 Frequency +T15 Temporal 381 401 depuis le 27/05/2022 +A14 AttTemp T15 Duration +A15 AttDate T15 StartDate +A16 RefTemp T15 Absolute +T16 Chemical_and_drugs 403 413 Salbutamol +A17 Certainty T16 Certain +A18 Temporality T16 Present +T17 Chemical_and_drugs 428 434 100 µg +A19 Tech T17 strength +A20 Certainty T17 Conditional +A21 Temporality T17 Present +T18 Chemical_and_drugs 448 458 inhalateur +A22 Tech T18 form +T19 Chemical_and_drugs 471 483 voie inhalée +A23 Tech T19 
route +T20 Chemical_and_drugs 502 513 Paracétamol +A24 Certainty T20 Conditional +A25 Temporality T20 Present +T21 Chemical_and_drugs 528 534 500 mg +A26 Tech T21 strength +T22 Chemical_and_drugs 535 541 per os +A27 Tech T22 route +T23 Temporal 560 588 maximum quatre fois par jour +A28 AttTemp T23 Frequency +T24 DISO 598 610 maux de tête +A29 Certainty T24 Conditional +T25 DISO 751 754 HTA +A30 Certainty T25 Certain +T26 Chemical_and_drugs 787 797 Amlodipine +A31 Action T26 Increase +A32 Certainty T26 Certain +A33 Temporality T26 Present +T27 Chemical_and_drugs 800 805 10 mg +A34 Tech T27 dosage +T28 DISO 845 851 asthme +A35 Certainty T28 Certain +T29 DISO 917 929 maux de tête +A36 Certainty T29 Conditional +T30 Chemical_and_drugs 970 981 paracétamol +A37 Certainty T30 Conditional +R5 Depend Arg1:T11 Arg2:T13 +R6 Depend Arg1:T11 Arg2:T12 +R7 Depend Arg1:T11 Arg2:T14 +R8 Depend Arg1:T11 Arg2:T15 +R9 Depend Arg1:T16 Arg2:T17 +R10 Depend Arg1:T16 Arg2:T18 +R11 Depend Arg1:T16 Arg2:T19 +R12 Depend Arg1:T20 Arg2:T21 +R13 Depend Arg1:T20 Arg2:T22 +R14 Depend Arg1:T20 Arg2:T23 +R15 Depend Arg1:T26 Arg2:T27 diff --git a/tests/resources/relations/text.txt b/tests/resources/relations/text.txt new file mode 100644 index 000000000..4c0257620 --- /dev/null +++ b/tests/resources/relations/text.txt @@ -0,0 +1,3 @@ +Le patient présente une fatigue chronique et des maux de tête fréquents. À l'examen clinique, sa tension artérielle est de 145/90 mmHg, sa glycémie à jeun est de 7.8 mmol/L et sa saturation en oxygène est de 98 %. +Traitements en cours : Amlodipine à une dose de 5 mg sous forme de comprimé pris par voie orale une fois par jour. Metformine à une dose de 500 mg deux fois par jour depuis le 27/05/2022. Salbutamol à une dose de 100 µg sous forme d'inhalateur utilisé par voie inhalée en cas de besoin. Paracétamol à une dose de 500 mg per os en cas de besoin, maximum quatre fois par jour pour les maux de tête. +Recommandations : Surveillance régulière de la tension artérielle et de la glycémie. Consultation chez un cardiologue pour évaluation de l'HTA, dans l'attente majoration de l'Amlodipine à 10 mg. Poursuite du traitement actuel pour l'asthme avec usage de l'inhalateur en cas de symptômes. Réévaluation des maux de tête si persistants malgré le traitement par paracétamol. 
\ No newline at end of file From fd7f6f260666001157fd021662260d88acbcb9ab Mon Sep 17 00:00:00 2001 From: pidoux7 Date: Wed, 3 Jul 2024 15:32:52 +0200 Subject: [PATCH 02/18] feat: add support for reading and dumping relations in standoff files using Brat --- edsnlp/data/converters.py | 181 +++++++++++++++++++++++++++----------- edsnlp/data/standoff.py | 28 +++--- 2 files changed, 142 insertions(+), 67 deletions(-) diff --git a/edsnlp/data/converters.py b/edsnlp/data/converters.py index 65b158318..a8a8c21f8 100644 --- a/edsnlp/data/converters.py +++ b/edsnlp/data/converters.py @@ -66,7 +66,7 @@ def validate_kwargs(converter, kwargs): return {**(d.pop(vd.v_kwargs_name, None) or {}), **d} -class SequenceStr: +class SequenceStr: @classmethod def __get_validators__(cls): yield cls.validate @@ -217,7 +217,7 @@ def __init__( tokenizer: Optional[Tokenizer] = None, span_setter: SpanSetterArg = {"ents": True, "*": True}, span_attributes: Optional[AttributesMappingArg] = None, - span_rel: Optional[AttributesMappingArg] = None, # to be decided whether we keep this + span_rel: Optional[AttributesMappingArg] = None, # to be decided whether we keep this keep_raw_attribute_values: bool = False, bool_attributes: SequenceStr = [], default_attributes: AttributesMappingArg = {}, @@ -250,11 +250,13 @@ def __call__(self, obj): self.tokenizer = tokenizer or (nlp.tokenizer if nlp is not None else None) self.span_setter = span_setter self.span_attributes = span_attributes # type: ignore + self.span_rel = span_rel # to be decided whether we keep this self.keep_raw_attribute_values = keep_raw_attribute_values self.default_attributes = default_attributes self.notes_as_span_attribute = notes_as_span_attribute @@ -244,12 +249,17 @@ def __call__(self, obj): if not Span.has_extension(dst): Span.set_extension(dst, default=None) ############## Modification for relations ############### - dict_entities={} ## dictionary to store the entities + dict_entities = {} ## dictionary to store the entities for ent in obj.get("entities") or (): - begin = min(f["begin"] for f in ent["fragments"]) # start of the entity - end = max(f["end"] for f in ent["fragments"]) # end of the entity - dict_entities[ent['entity_id']] = ent['label'] + ';' + str(begin) + ';' + str(end) # store the entities + begin = min(f["begin"] for f in ent["fragments"]) # start of the entity + end = max(f["end"] for f in ent["fragments"]) # end of the entity + dict_entities[ent["entity_id"]] = ( + ent["label"] + ";" + str(begin) + ";" + str(end) + ) # store the entities fragments = ( [ { @@ -270,13 +272,18 @@ def __call__(self, obj): fragment["begin"], fragment["end"], label=ent["label"], - alignment_mode="expand", # add id + alignment_mode="expand", # add id + ) + attributes = ( + {a["label"]: a["value"] for a in ent["attributes"]} + if isinstance(ent["attributes"], list) + else ent["attributes"] ) if self.notes_as_span_attribute and ent["notes"]: ent["attributes"][self.notes_as_span_attribute] = "|".join( note["value"] for note in ent["notes"] ) - for label, value in ent["attributes"].items(): + for label, value in attributes.items(): new_name = ( self.span_attributes.get(label, None) if self.span_attributes is not None @@ -300,7 +307,6 @@ def __call__(self, obj): span._.set(new_name, value) spans.append(span) - set_spans(doc, spans, span_setter=self.span_setter) for attr, value in self.default_attributes.items(): @@ -308,42 +314,65 @@ def __call__(self, obj): if span._.get(attr) is None: span._.set(attr, value) - ############## Modification for relations ############### # Add the relations as spans - if self.span_rel is None and not Span.has_extension( - 'rel' - ): - Span.set_extension('rel', default=[]) - - for rel in obj.get("relations") or (): # iterate over relations - for label in doc.spans: # iterate over source labels - for i, spa in enumerate(doc.spans[label]): # iterate over source spans + if self.span_rel is None and not Span.has_extension("rel"): + Span.set_extension("rel", default=[]) + + for rel in obj.get("relations") or
(): # iterate over relations + for label in doc.spans: # iterate over source labels + for i, spa in enumerate(doc.spans[label]): # iterate over source spans bo = False - + # relations + if dict_entities[rel["from_entity_id"]].split(";") == [ + label, + str(spa.start_char), + str(spa.end_char), + ]: # if the source entity matches the span + for label2 in doc.spans: # iterate over target labels + for j, spa2 in enumerate( + doc.spans[label2] + ): # iterate over target spans + if dict_entities[rel["to_entity_id"]].split(";") == [ + label2, + str(spa2.start_char), + str(spa2.end_char), + ]: # if the target entity matches the span + relation = { + "type": rel["relation_label"], + "target": doc.spans[label2][j], + } # create the relation + doc.spans[label][i]._.rel.append( + relation + ) # add the relation to the span + bo = True break - if bo == True: + if bo: break bo = False # inverse relations - if dict_entities[rel["to_entity_id"]].split(';') == [ + if dict_entities[rel["to_entity_id"]].split(";") == [ + label, + str(spa.start_char), + str(spa.end_char), + ]: + for label2 in doc.spans: for j, spa2 in enumerate(doc.spans[label2]): - if dict_entities[rel["from_entity_id"]].split(';') == [ + if dict_entities[rel["from_entity_id"]].split(";") == [ + label2, + str(spa2.start_char), + str(spa2.end_char), + ]: + relation = { + "type": "inv_" + rel["relation_label"], + "target": doc.spans[label2][j], + } doc.spans[label][i]._.rel.append(relation) - bo=True + bo = True break - if bo == True: + if bo: break return doc @@ -390,29 +419,75 @@ def __init__( def __call__(self, doc): spans = get_spans(doc, self.span_getter) + entities = [ + { + "entity_id": i, + "fragments": [ + { + "begin": ent.start_char, + "end": ent.end_char, + } + ], + "attributes": { + obj_name: getattr(ent._, ext_name) + for ext_name, obj_name in self.span_attributes.items() + if ent._.has(ext_name) + }, + "label": ent.label_, + } + for i, ent in enumerate(sorted(dict.fromkeys(spans))) + ] + + # mapping between entities and their `entity_id` + entity_map = { + ( + ent["fragments"][0]["begin"], + ent["fragments"][0]["end"], + ent["label"], + ): ent["entity_id"] + for ent in entities + } + + # doesn't include 'inv_' relations + relations = [] + relation_idx = 1 + for span_label, span_list in doc.spans.items(): + for spa in span_list: + source_entity_id = entity_map.get( + (spa.start_char, spa.end_char, spa.label_) + ) + for rel in spa._.rel: + if not rel["type"].startswith("inv_"): + target_entity_id = entity_map.get( + ( + rel["target"].start_char, + rel["target"].end_char, + rel["target"].label_, + ) + ) + if ( + source_entity_id is not None + and target_entity_id is not None + ): + relations.append( + { + "rel_id":
relation_idx, + "from_entity_id": source_entity_id, + "relation_type": rel["type"], + "to_entity_id": target_entity_id, + } + ) + relation_idx += 1 + + # final object obj = { FILENAME: doc._.note_id, "doc_id": doc._.note_id, "text": doc.text, - "entities": [ - { - "entity_id": i, - "fragments": [ - { - "begin": ent.start_char, - "end": ent.end_char, - } - ], - "attributes": { - obj_name: getattr(ent._, ext_name) - for ext_name, obj_name in self.span_attributes.items() - if ent._.has(ext_name) - }, - "label": ent.label_, - } - for i, ent in enumerate(sorted(dict.fromkeys(spans))) - ], + "entities": entities, + "relations": relations, } + return obj @@ -476,7 +551,7 @@ def __init__( *, tokenizer: Optional[PipelineProtocol] = None, span_setter: SpanSetterArg = {"ents": True, "*": True}, - doc_attributes: AttributesMappingArg = {}, + doc_attributes: AttributesMappingArg = {"note_datetime": "note_datetime"}, span_attributes: Optional[AttributesMappingArg] = None, default_attributes: AttributesMappingArg = {}, bool_attributes: SequenceStr = [], diff --git a/edsnlp/data/standoff.py b/edsnlp/data/standoff.py index 6dece8673..6492abddd 100644 --- a/edsnlp/data/standoff.py +++ b/edsnlp/data/standoff.py @@ -264,20 +264,20 @@ def dump_standoff_file( file=f, ) attribute_idx += 1 - - # fmt: off - # if "relations" in doc: - # for i, relation in enumerate(doc["relations"]): - # entity_from = entities_ids[relation["from_entity_id"]] - # entity_to = entities_ids[relation["to_entity_id"]] - # print( - # "R{}\t{} Arg1:{} Arg2:{}\t".format( - # i + 1, str(relation["label"]), entity_from, - # entity_to - # ), - # file=f, - # ) - # fmt: on + # Add handling of relations + relation_idx = 1 + if "relations" in doc: + for relation in doc["relations"]: + print( + "R{}\t{} Arg1:{} Arg2:{}".format( + relation_idx, + relation["relation_type"], + entities_ids[relation["from_entity_id"]], + entities_ids[relation["to_entity_id"]], + ), + file=f, + ) + relation_idx += 1 From 452c52fc1f2b1091cf9613eb218498d6ef2b9ee0 Mon Sep 17 00:00:00 2001 From: pidoux7 Date: Wed, 3 Jul 2024 15:35:26 +0200 Subject: [PATCH 03/18] fix: import statement in relations module 'eds.relations' Added: rule-based model for relation prediction using proximity and sentences --- edsnlp/pipes/misc/relations/__init__.py | 2 +- edsnlp/pipes/misc/relations/factory.py | 4 +- edsnlp/pipes/misc/relations/patterns.py | 32 +- edsnlp/pipes/misc/relations/relations.py | 390 +++++++++++++++++------ 4 files changed, 313 insertions(+), 115 deletions(-) diff --git a/edsnlp/pipes/misc/relations/__init__.py b/edsnlp/pipes/misc/relations/__init__.py index e5f49bc5a..5fd446779 100644 --- a/edsnlp/pipes/misc/relations/__init__.py +++ b/edsnlp/pipes/misc/relations/__init__.py @@ -1 +1 @@ -from .relations import RelationsMatcher \ No newline at end of file +from .relations import RelationsMatcher diff --git a/edsnlp/pipes/misc/relations/factory.py b/edsnlp/pipes/misc/relations/factory.py index 2a26f27ed..7904fca3b 100644 --- a/edsnlp/pipes/misc/relations/factory.py +++ b/edsnlp/pipes/misc/relations/factory.py @@ -6,7 +6,7 @@ scheme=None, use_sentences=False, clean_rel=False, - proximity_method = "right", + proximity_method="right", max_dist=45, ) @@ -14,4 +14,4 @@ "eds.relations", assigns=["doc.spans"], deprecated=["relations"], -)(RelationsMatcher) \ No newline at end of file +)(RelationsMatcher) diff --git a/edsnlp/pipes/misc/relations/patterns.py b/edsnlp/pipes/misc/relations/patterns.py index c4cb8fa13..cba035219 100644 ---
a/edsnlp/pipes/misc/relations/patterns.py +++ b/edsnlp/pipes/misc/relations/patterns.py @@ -1,17 +1,17 @@ scheme = [ - { - "subject": [{"label": "Chemical_and_drugs", "attr": {"Tech": [None]}}], - "object": [ - { - "label": "Temporal", - "attr": {"AttTemp": [None, "Duration", "Date", "Time"]}, - }, - { - "label": "Chemical_and_drugs", - "attr": {"Tech": ["dosage", "route", "strength", "form"]}, - }, - ], - "type": "Depend", - "inv_type": "inv_Depend", - }, - ] \ No newline at end of file + { + "source": [{"label": "Chemical_and_drugs", "attr": {"Tech": [None]}}], + "target": [ + { + "label": "Temporal", + "attr": {"AttTemp": [None, "Duration", "Date", "Frequency"]}, + }, + { + "label": "Chemical_and_drugs", + "attr": {"Tech": ["dosage", "route", "strength", "form"]}, + }, + ], + "type": "Depend", + "inv_type": "inv_Depend", + }, +] diff --git a/edsnlp/pipes/misc/relations/relations.py b/edsnlp/pipes/misc/relations/relations.py index 9e53bbd6a..49aba0266 100644 --- a/edsnlp/pipes/misc/relations/relations.py +++ b/edsnlp/pipes/misc/relations/relations.py @@ -1,45 +1,199 @@ -from typing import Dict, Iterable, List, Union +import json +from typing import Any, Dict, List, Union +import numpy as np from loguru import logger - -from spacy.tokens import Doc, Span -from typing import Any from numpy.typing import NDArray +from spacy.tokens import Doc, Span from edsnlp.core import PipelineProtocol from edsnlp.pipes.misc.relations import patterns -import math as m -import numpy as np -import json class RelationsMatcher: - """ A spaCy EDSNLP pipeline component to find relations between entities based on their proximity. - scheme = [ - { - "subject": [{"label": "Chemical_and_drugs", "attr": {"Tech": [None]}}], - "object": [ - { - "label": "Temporal", - "attr": {"AttTemp": [None, "Duration", "Date", "Time"]}, - }, - { - "label": "Chemical_and_drugs", - "attr": {"Tech": ["dosage", "route", "strength", "form"]}, - }, - ], - "type": "Depend", - "inv_type": "inv_Depend", - }, - ] - """ - + ''' + The `eds.relations` component links source and target named entities, with or without attribute constraints. + This component is rule-based and utilizes character proximity + to determine relationships. + + + Examples + -------- + In this simple example, we extract drugs and dates \ + from a text and link them together. + ```python + import edsnlp, edsnlp.pipes as eds + + text = """ + Prise pendant 3 semaines d'Amlodipine 5mg per os une fois par jour \ + mais l'HTA reste mal contrôlée. + Metformine 500 mg deux fois par jour à partir du 27/05/2022. + Consultation chez un cardiologue le 11/07 pour évaluation de l'HTA, \ + dans l'attente majoration de l'AMLODIPINE à 10 mg.
+ """ + + scheme = { + "source": [{"label": "drug", "attr": None}], + "target": [ + {"label": "dates", "attr": None}, + {"label": "durations", "attr": None}, + ], + "type": "Temporal", + "inv_type": "inv_Temporal", + } + + nlp = edsnlp.blank("eds") + + # Extraction of entities + nlp.add_pipe("eds.drugs") + nlp.add_pipe("eds.dates") + # Extraction of sentences + nlp.add_pipe("eds.sentences") + # Extraction of relations + nlp.add_pipe( + "eds.relations", + config={ + "scheme": scheme, + "use_sentences": True, + "clean_rel": True, + "proximity_method": "sym", + "max_dist": 60, + }, + ) + doc = nlp(text) + + for label in doc.spans: + print("Label: ", label, "\t Entities :", doc.spans[label]) + for span in doc.spans[label]: + print("\t Entity :", span, "\t Relations :", span._.rel) + + # Out: Label: drug \ + # Entities : [Amlodipine, Metformine, AMLODIPINE] + # Entity : Amlodipine Relations : \ + # [{'type': 'Temporal', target': pendant 3 semaines}] + # Entity : Metformine Relations : \ + [{"type": "Temporal", "target": 27 / 05 / 2022}] + # Entity : AMLODIPINE Relations : [] + + # Label: dates Entities : [27/05/2022, 11/07] + # Entity : 27/05/2022 Relations : \ + [{"type": "inv_Temporal", "target": Metformine}] + # Entity : 11/07 Relations : [] + + # Label: durations Entities : [pendant 3 semaines] + # Entity : pendant 3 semaines \ + Relations: [{"type": "inv_Temporal", "target": Amlodipine}] + + # Label: periods Entities : [] + ``` + + Extensions + ---------- + The `eds.relations` pipeline adds and declares one extension + on the `Span` objects called `rel`. By default rel is an empty list. + + The `rel` extension is a list of dictionaries + containing the type of the relation and the target `Span`. + It automatically adds the inverse relation to the target `Span`. + + Parameters + ---------- + nlp : PipelineProtocol + The pipeline object + name : str + Name of the component + scheme : Union[Union[Dict, List[Dict]],str] + The scheme to use to match the relations + use_sentences: bool = True + Whether or not to use the `eds.sentences` matcher to improve results + proximity_method: str = "right" + The method to use to calculate the proximity between the entities + "sym" : symmetrical distance + "start" : distance between the start char of the entities + "end" : distance between the end char of the entities + "middle" : distance between the middle of the entities + "right" : distance between the end of the source and the start of the target + "left" : distance between the end of the target and the start of the source + max_dist: int = 45 + The maximum distance between the entities to consider them as related + clean_rel: bool = True + Whether or not to clean the relations before adding new ones + + Scheme + ------ + It can be a dictionary (one relation), \ + a list of dictionaries (one or more relations) + or a string indicating the path of a json file. + + Each dictionary should contain the keys \ + `source`, `target`, `type` and `inv_type`. + + `source` and `target` are lists of dictionaries \ + containing the keys `label` and `attr`. + + `label` is the label of the entity to match. + + `attr` is a dictionary containing the attributes \ + to match on or None if no attribute is needed. + + `type` is the type of the relation. + + `inv_type` is the inverse type of the relation. 
+ ```json + [ + { + "source": [ + { + "label": "Chemical_and_drugs", + "attr": { + "Tech": [ + null + ] + } + } + ], + "target": [ + { + "label": "Temporal", + "attr": { + "AttTemp": [ + "Duration", + "Date", + "Frequency", + "Time" + ] + } + }, + { + "label": "Chemical_and_drugs", + "attr": { + "Tech": [ + "dosage", + "route", + "strength", + "form" + ] + } + } + ], + "type": "Depend", + "inv_type": "inv_Depend" + } + ] + ``` + + Authors and citation + -------------------- + The `eds.relations` component was developed by AP-HP's Data Science team. + + ''' + def __init__( self, nlp: PipelineProtocol, name: str = "relations", *, - scheme: Union[Union[Dict, List[Dict]],str] = None, + scheme: Union[Union[Dict, List[Dict]], str] = None, use_sentences: bool = False, proximity_method: str = "right", clean_rel: bool = True, @@ -53,7 +207,6 @@ def __init__( if scheme is None: scheme = patterns.scheme if isinstance(scheme, str): - # open the json file and load it into a variable if not scheme.endswith(".json"): raise ValueError("scheme must be a json file") with open(scheme) as f: scheme = json.load(f) if isinstance(scheme, dict): scheme = [scheme] self.check_scheme(scheme) self.scheme = scheme @@ -91,56 +244,59 @@ def __init__( self.max_dist = max_dist self.set_extensions() - + def check_scheme(self, schemes): for scheme in schemes: if not isinstance(scheme, dict): raise ValueError("scheme must be a dictionary") - if "subject" not in scheme: - raise ValueError("scheme must contain a 'subject' key") - if "object" not in scheme: - raise ValueError("scheme must contain an 'object' key") + if "source" not in scheme: + raise ValueError("scheme must contain a 'source' key") + if "target" not in scheme: + raise ValueError("scheme must contain a 'target' key") if "type" not in scheme: raise ValueError("scheme must contain a 'type' key") if "inv_type" not in scheme: raise ValueError("scheme must contain an 'inv_type' key") - if not isinstance(scheme["subject"], list): - raise ValueError("scheme['subject'] must be a list") + if not isinstance(scheme["source"], list): + raise ValueError("scheme['source'] must be a list") - if not isinstance(scheme["object"], list): - raise ValueError("scheme['object'] must be a list") + if not isinstance(scheme["target"], list): + raise ValueError("scheme['target'] must be a list") if not isinstance(scheme["type"], str): raise ValueError("scheme['type'] must be a string") if not isinstance(scheme["inv_type"], str): raise ValueError("scheme['inv_type'] must be a string") - for sub in scheme["subject"]: + for sub in scheme["source"]: if not isinstance(sub, dict): - raise ValueError("scheme['subject'] must contain dictionaries") + raise ValueError("scheme['source'] must contain dictionaries") if "label" not in sub: - raise ValueError("scheme['subject'] must contain a 'label' key") + raise ValueError("scheme['source'] must contain a 'label' key") if not isinstance(sub["label"], str): - raise ValueError("scheme['subject']['label'] must be a string") + raise ValueError("scheme['source']['label'] must be a string") if "attr" in sub: if sub["attr"] is not None and not isinstance(sub["attr"], dict): - raise ValueError("scheme['subject']['attr'] must be a dictionary or None") + raise ValueError( + "scheme['source']['attr'] must be a dictionary or None" + ) - for obj in scheme["object"]: + for obj in scheme["target"]: if not isinstance(obj, dict): - raise ValueError("scheme['object'] must contain dictionaries") + raise ValueError("scheme['target'] must contain dictionaries") if "label" not in obj: - raise ValueError("scheme['object'] must contain a 'label' key") + raise
ValueError("scheme['target'] must contain a 'label' key") if not isinstance(obj["label"], str): - raise ValueError("scheme['object']['label'] must be a string") + raise ValueError("scheme['target']['label'] must be a string") if "attr" in obj: if obj["attr"] is not None and not isinstance(obj["attr"], dict): - raise ValueError("scheme['object']['attr'] must be a dictionary or None") + raise ValueError( + "scheme['target']['attr'] must be a dictionary or None" + ) return True @classmethod def set_extensions(cls) -> None: - """Set the extension rel for the Span object. - """ + """Set the extension rel for the Span target.""" if not Span.has_extension("rel"): Span.set_extension("rel", default=[]) - + def clean_relations(self, doc: Doc) -> Doc: """Remove the relations from the doc @@ -157,7 +313,8 @@ def clean_relations(self, doc: Doc) -> Doc: return doc def __call__(self, doc: Doc) -> Doc: - """find the relations in the doc based on the proximity of the entities attributes + """find the relations in the doc based \ + on the proximity of the entities attributes Args: doc (Doc): the doc to be processed @@ -178,43 +335,60 @@ def __call__(self, doc: Doc) -> Doc: for i, span_obj in enumerate(dict_r[r]["spans_obj"]): if distances[min_distance_indices[i]][i] <= self.max_dist: span_sub = dict_r[r]["spans_sub"][min_distance_indices[i]] - if self.use_sentences and not self.sentences(doc, span_obj["span"], span_sub["span"]): + if self.use_sentences and not self.sentences( + doc, span_obj["span"], span_sub["span"] + ): continue - doc.spans[span_obj['label']][span_obj['num_span']]._.rel.append( - {"type": dict_r[r]["inv_type"], "target": doc.spans[span_sub["label"]][span_sub["num_span"]]} - ) - - doc.spans[span_sub['label']][span_sub['num_span']]._.rel.append( - {"type": dict_r[r]["type"], "target": doc.spans[span_obj["label"]][span_obj["num_span"]]} - ) + doc.spans[span_obj["label"]][span_obj["num_span"]]._.rel.append( + { + "type": dict_r[r]["inv_type"], + "target": doc.spans[span_sub["label"]][ + span_sub["num_span"] + ], + } + ) + + doc.spans[span_sub["label"]][span_sub["num_span"]]._.rel.append( + { + "type": dict_r[r]["type"], + "target": doc.spans[span_obj["label"]][ + span_obj["num_span"] + ], + } + ) return doc - + def sentences(self, doc: Doc, span_obj: Span, span_sub: Span) -> bool: - """ Check if span_obj and span_sub are in the same sentence. + """Check if span_obj and span_sub are in the same sentence. Args: - doc (Doc): EDSNLP Doc object + doc (Doc): EDSNLP Doc target span_obj (Span): span representing the target span_sub (Span): span representing the source Returns: - bool: True if span_obj and span_sub are in the same sentence, False otherwise. + bool: True if span_obj and span_sub \ + are in the same sentence, False otherwise. 
""" for sent in doc.sents: - if span_obj.start >= sent.start and span_obj.end <= sent.end and span_sub.start >= sent.start and span_sub.end <= sent.end: + if ( + span_obj.start >= sent.start + and span_obj.end <= sent.end + and span_sub.start >= sent.start + and span_sub.end <= sent.end + ): return True return False - def find_relations(self, doc: Doc) -> Dict: """ - Detect the potential subjects and objects in the document + Detect the potential sources and targets in the document Args: - doc (Doc): EDSNLP Doc object + doc (Doc): EDSNLP Doc target Returns: - Dict: dict containing the potential subjects and objects + Dict: dict containing the potential sources and targets """ dict_r = {} for r, relation in enumerate(self.scheme): @@ -226,8 +400,8 @@ def find_relations(self, doc: Doc) -> Dict: "type": relation["type"], "inv_type": relation["inv_type"], } - # Treatment of objects - for obj in relation["object"]: + # Treatment of targets + for obj in relation["target"]: label_obj = obj["label"] attr_obj = obj["attr"] if label_obj in doc.spans: @@ -239,16 +413,28 @@ def find_relations(self, doc: Doc) -> Dict: dict_r[r]["mat_obj"].append( [span_obj.start_char, span_obj.end_char] ) - dict_r[r]["spans_obj"].append({'label': label_obj, 'num_span': num_span_obj, 'span': span_obj}) + dict_r[r]["spans_obj"].append( + { + "label": label_obj, + "num_span": num_span_obj, + "span": span_obj, + } + ) else: for num_span_obj, span_obj in enumerate(doc.spans[label_obj]): dict_r[r]["mat_obj"].append( [span_obj.start_char, span_obj.end_char] ) - dict_r[r]["spans_obj"].append({'label': label_obj, 'num_span': num_span_obj, 'span': span_obj}) + dict_r[r]["spans_obj"].append( + { + "label": label_obj, + "num_span": num_span_obj, + "span": span_obj, + } + ) - # Treatment of subjects - for sub in relation["subject"]: + # Treatment of sources + for sub in relation["source"]: label_sub = sub["label"] attr_sub = sub["attr"] if label_sub in doc.spans: @@ -260,13 +446,25 @@ def find_relations(self, doc: Doc) -> Dict: dict_r[r]["mat_sub"].append( [span_sub.start_char, span_sub.end_char] ) - dict_r[r]["spans_sub"].append({'label': label_sub, 'num_span': num_span_sub, 'span': span_sub}) + dict_r[r]["spans_sub"].append( + { + "label": label_sub, + "num_span": num_span_sub, + "span": span_sub, + } + ) else: for num_span_sub, span_sub in enumerate(doc.spans[label_sub]): dict_r[r]["mat_sub"].append( [span_sub.start_char, span_sub.end_char] ) - dict_r[r]["spans_sub"].append({'label': label_sub, 'num_span': num_span_sub, 'span': span_sub}) + dict_r[r]["spans_sub"].append( + { + "label": label_sub, + "num_span": num_span_sub, + "span": span_sub, + } + ) # Convert lists to numpy arrays for easier manipulation later for r in dict_r: @@ -275,36 +473,38 @@ def find_relations(self, doc: Doc) -> Dict: return dict_r - def calculate_min_distances(self, subjects:NDArray[Any], objects:NDArray[Any]) -> NDArray[Any]: - """ calculate the minimum distance between subjects and objects + def calculate_min_distances( + self, sources: NDArray[Any], targets: NDArray[Any] + ) -> NDArray[Any]: + """calculate the minimum distance between sources and targets Args: - subjects (NDArray[Any]): entities to be used as subjects - objects (NDArray[Any]): entities to be used as objects + sources (NDArray[Any]): entities to be used as sources + targets (NDArray[Any]): entities to be used as targets Returns: - NDArray[Any]: the index of the subject that is closest to the object + NDArray[Any]: the index of the source that is closest to the target """ - 
subjects_expanded = subjects[:, np.newaxis, :] - objects_expanded = objects[np.newaxis, :, :] + sources_expanded = sources[:, np.newaxis, :] + targets_expanded = targets[np.newaxis, :, :] # calculate the distances between the entities - distance_start_to_end = subjects_expanded[:, :, 0] - objects_expanded[:, :, 1] - distance_end_to_start = objects_expanded[:, :, 0] - subjects_expanded[:, :, 1] - distance_start_to_start = subjects_expanded[:, :, 0] - objects_expanded[:, :, 0] - distance_end_to_end = objects_expanded[:, :, 1] - subjects_expanded[:, :, 1] + distance_start_to_end = sources_expanded[:, :, 0] - targets_expanded[:, :, 1] + distance_end_to_start = targets_expanded[:, :, 0] - sources_expanded[:, :, 1] + distance_start_to_start = sources_expanded[:, :, 0] - targets_expanded[:, :, 0] + distance_end_to_end = targets_expanded[:, :, 1] - sources_expanded[:, :, 1] distance_middle = ( - subjects_expanded[:, :, 0] + subjects_expanded[:, :, 1] - ) / 2 - (objects_expanded[:, :, 0] + objects_expanded[:, :, 1]) / 2 + np.abs(distance_start_to_start) + np.abs(distance_end_to_end) + ) / 2 if self.proximity_method == "sym": distance_middle = np.abs( np.minimum(distance_start_to_start, distance_end_to_end) ) # determine the mask for the left and right side of the entities - mask_left = objects_expanded[:, :, 1] <= subjects_expanded[:, :, 0] - mask_right = subjects_expanded[:, :, 1] <= objects_expanded[:, :, 0] + mask_left = targets_expanded[:, :, 1] <= sources_expanded[:, :, 0] + mask_right = sources_expanded[:, :, 1] <= targets_expanded[:, :, 0] # Assign the distances based on the mask distances = np.where(mask_left, np.abs(distance_start_to_end), np.inf) distances = np.where(mask_right, np.abs(distance_end_to_start), distances) mask_middle = np.logical_not(np.logical_or(mask_left, mask_right)) distances = np.where(mask_middle, distance_middle, distances) distances = np.where(distances == 0, np.inf, distances) @@ -342,7 +542,7 @@ def calculate_min_distances(self, subjects:NDArray[Any], objects:NDArray[Any]) - min_distance_indices = np.argmin(distances, axis=0) return min_distance_indices, distances - + def filter_spans( self, span: Span, attr_name: str, attr_values: list, label: str ) -> bool: @@ -358,9 +558,7 @@ def filter_spans( bool: True if the attribute value is in attr_values, False otherwise """ # Get the attribute value or None if it doesn't exist - attr_value = getattr(span._, attr_name, None) + attr_value = getattr(span._, attr_name, None) if attr_value in attr_values: return True return False - - From 9ef378d3a926acd28c81223c2cb272d56b124421 Mon Sep 17 00:00:00 2001 From: pidoux7 Date: Wed, 3 Jul 2024 15:44:15 +0200 Subject: [PATCH 04/18] Added: relation testing for eds.relations and the Brat relations connector --- tests/pipelines/misc/test_relations.py | 44 ++++++++++++------------ tests/resources/relations/relations.json | 4 +-- tests/resources/relations/text.txt | 2 +- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/pipelines/misc/test_relations.py b/tests/pipelines/misc/test_relations.py index 8a8437818..9a08c4274 100644 --- a/tests/pipelines/misc/test_relations.py +++ b/tests/pipelines/misc/test_relations.py @@ -1,22 +1,21 @@ -import sys import os -from pytest import mark + import pytest +from pytest import mark from spacy.tokens import Doc, Span -# Make sure the path to edsnlp comes first in sys.path -sys.path.insert(0, "/home/pidoux/edsnlp") - -# Import the modules after adding the path import edsnlp -import edsnlp.pipes.core as eds + @mark.parametrize("use_sentences", [True, False]) @mark.parametrize("clean_rel", [True, False]) @mark.parametrize( + "proximity_method", ["sym", "right", "left", "middle",
"start", "end"] +) +@mark.parametrize("max_dist", [1, 45, 100]) def test_relations(use_sentences, clean_rel, proximity_method, max_dist): - dossier = "../../resources/relations/" + dossier = "../../resources/relations/" doc_iterator = edsnlp.data.read_standoff(dossier) corpus = list(doc_iterator) assert len(corpus) > 0 @@ -25,11 +24,11 @@ def test_relations(use_sentences, clean_rel, proximity_method, max_dist): for label in doc.spans: for span in doc.spans[label]: assert isinstance(span, Span) - assert span.has_extension('rel') + assert span.has_extension("rel") for rel in span._.rel: - assert isinstance(rel['target'], Span) - assert isinstance(rel['type'], str) - assert rel['type'] == 'Depend' or rel['type'] == 'inv_Depend' + assert isinstance(rel["target"], Span) + assert isinstance(rel["type"], str) + assert rel["type"] == "Depend" or rel["type"] == "inv_Depend" nlp = edsnlp.blank("eds") nlp.add_pipe("eds.sentences") @@ -37,10 +36,10 @@ def test_relations(use_sentences, clean_rel, proximity_method, max_dist): "eds.relations", config={ "scheme": os.path.join(dossier, "relations.json"), - "use_sentences": use_sentences, - "clean_rel": clean_rel, - "proximity_method": proximity_method, - "max_dist": max_dist, + "use_sentences": use_sentences, + "clean_rel": clean_rel, + "proximity_method": proximity_method, + "max_dist": max_dist, }, ) @@ -50,11 +49,12 @@ def test_relations(use_sentences, clean_rel, proximity_method, max_dist): for span in doc.spans[label]: print(span, span._.rel) assert isinstance(span, Span) - assert span.has_extension('rel') + assert span.has_extension("rel") for rel in span._.rel: - assert isinstance(rel['target'], Span) - assert isinstance(rel['type'], str) - assert rel['type'] == 'Depend' or rel['type'] == 'inv_Depend' - + assert isinstance(rel["target"], Span) + assert isinstance(rel["type"], str) + assert rel["type"] == "Depend" or rel["type"] == "inv_Depend" + + if __name__ == "__main__": pytest.main() diff --git a/tests/resources/relations/relations.json b/tests/resources/relations/relations.json index a3d77fe13..8e9e19387 100644 --- a/tests/resources/relations/relations.json +++ b/tests/resources/relations/relations.json @@ -1,6 +1,6 @@ [ { - "subject": [ + "source": [ { "label": "Chemical_and_drugs", "attr": { @@ -10,7 +10,7 @@ } } ], - "object": [ + "target": [ { "label": "Temporal", "attr": { diff --git a/tests/resources/relations/text.txt b/tests/resources/relations/text.txt index 4c0257620..4ed1a3e9f 100644 --- a/tests/resources/relations/text.txt +++ b/tests/resources/relations/text.txt @@ -1,3 +1,3 @@ Le patient présente une fatigue chronique et des maux de tête fréquents. À l'examen clinique, sa tension artérielle est de 145/90 mmHg, sa glycémie à jeun est de 7.8 mmol/L et sa saturation en oxygène est de 98 %. Traitements en cours : Amlodipine à une dose de 5 mg sous forme de comprimé pris par voie orale une fois par jour. Metformine à une dose de 500 mg deux fois par jour depuis le 27/05/2022. Salbutamol à une dose de 100 µg sous forme d'inhalateur utilisé par voie inhalée en cas de besoin. Paracétamol à une dose de 500 mg per os en cas de besoin, maximum quatre fois par jour pour les maux de tête. -Recommandations : Surveillance régulière de la tension artérielle et de la glycémie. Consultation chez un cardiologue pour évaluation de l'HTA, dans l'attente majoration de l'Amlodipine à 10 mg. Poursuite du traitement actuel pour l'asthme avec usage de l'inhalateur en cas de symptômes. 
From 943ecd318d5f52b7faa2f6e18925b5b82d8862ad Mon Sep 17 00:00:00 2001
From: pidoux7
Date: Wed, 3 Jul 2024 15:47:33 +0200
Subject: [PATCH 05/18] Added: relations module documentation
 Fix: relations pipe import

---
 docs/concepts/torch-component.md |  2 +-
 docs/pipes/misc/index.md         |  1 +
 docs/pipes/misc/relations.md     |  8 ++++++++
 mkdocs.yml                       | 15 ++++++++++-----
 4 files changed, 20 insertions(+), 6 deletions(-)
 create mode 100644 docs/pipes/misc/relations.md

diff --git a/docs/concepts/torch-component.md b/docs/concepts/torch-component.md
index f946868e6..7065b9b14 100644
--- a/docs/concepts/torch-component.md
+++ b/docs/concepts/torch-component.md
@@ -82,7 +82,7 @@ In EDS-NLP, sharing a subcomponent is simply done by sharing the object between
     eds.ner_crf(
         ...,
         embedding=eds.transformer(
-            model_name="bert-base-uncased",
+            model="bert-base-uncased",
             window=128,
             stride=96,
         ),

diff --git a/docs/pipes/misc/index.md b/docs/pipes/misc/index.md
index d9df52c6e..c46966921 100644
--- a/docs/pipes/misc/index.md
+++ b/docs/pipes/misc/index.md
@@ -16,5 +16,6 @@ For instance, the date detection and normalisation pipeline falls in this catego
 | `eds.sections`  | Section detection                           |
 | `eds.reason`    | Rule-based hospitalisation reason detection |
 | `eds.tables`    | Tables detection                            |
+| `eds.relations` | Relations extraction                        |

diff --git a/docs/pipes/misc/relations.md b/docs/pipes/misc/relations.md
new file mode 100644
index 000000000..5157b9db1
--- /dev/null
+++ b/docs/pipes/misc/relations.md
@@ -0,0 +1,8 @@
+# Relations {: #edsnlp.pipes.misc.relations.factory.create_component }
+
+::: edsnlp.pipes.misc.relations.factory.create_component
+    options:
+        heading_level: 2
+        show_bases: false
+        show_source: false
+        only_class_level: true

diff --git a/mkdocs.yml b/mkdocs.yml
index 5878bfad8..360876410 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -78,6 +78,7 @@ nav:
           - pipes/misc/sections.md
           - pipes/misc/reason.md
           - pipes/misc/tables.md
+          - pipes/misc/relations.md
       - Named Entity Recognition:
           - Overview: pipes/ner/index.md
           - Scores:
@@ -119,11 +120,12 @@
     - Trainable components:
         - pipes/trainable/index.md
-        - pipes/trainable/embeddings/transformer.md
-        - pipes/trainable/embeddings/text_cnn.md
-        - pipes/trainable/embeddings/span_pooler.md
-        - pipes/trainable/ner.md
-        - pipes/trainable/span-qualifier.md
+        - 'Transformer': pipes/trainable/embeddings/transformer.md
+        - 'Text CNN': pipes/trainable/embeddings/text_cnn.md
+        - 'Span Pooler': pipes/trainable/embeddings/span_pooler.md
+        - 'NER': pipes/trainable/ner.md
+        - 'Span Classifier': pipes/trainable/span-classifier.md
+        - 'Span Linker': pipes/trainable/span-linker.md
     - tokenizers.md
     - Data Connectors:
         - data/index.md
@@ -180,6 +182,9 @@ hooks:
     - docs/scripts/plugin.py
 plugins:
+    - redirects:
+          redirect_maps:
+              'pipes/trainable/span-qualifier.md': 'pipes/trainable/span-classifier.md'
     - search
     - minify:
           minify_html: true
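The new documentation page describes the same knobs that the test matrix in the previous patch sweeps. A minimal usage sketch under those assumptions; the scheme path is a placeholder, and the comments gloss each parameter from its name and the tested values rather than from the rendered docs:

    import edsnlp

    nlp = edsnlp.blank("eds")
    nlp.add_pipe("eds.sentences")  # presumably required when use_sentences=True
    nlp.add_pipe(
        "eds.relations",
        config={
            "scheme": "relations.json",  # placeholder path to a source/target scheme file
            "use_sentences": True,       # presumably restricts candidate pairs to one sentence
            "clean_rel": True,           # presumably clears pre-existing relations first
            "proximity_method": "sym",   # or "right", "left", "middle", "start", "end"
            "max_dist": 45,              # maximum character distance between paired entities
        },
    )
    # With an NER pipe upstream, spans in doc.spans would then carry `_.rel` entries.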
From 47add63807b139c41a3c9f57c3eab3a8b017c494 Mon Sep 17 00:00:00 2001
From: pidoux7
Date: Wed, 3 Jul 2024 15:50:44 +0200
Subject: [PATCH 06/18] Fix: modified .gitignore so that the relations test
 can run

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index bf0d160a0..f8930d9a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,6 +53,7 @@ _build/
 *.tar.gz
 *.tsv
 *.ann
+!text.ann

 # Editors
 .idea

From 0b4eda6eeb5d2cc7ebb6f93e5912793e42c2cb54 Mon Sep 17 00:00:00 2001
From: pidoux7
Date: Wed, 3 Jul 2024 15:56:02 +0200
Subject: [PATCH 07/18] Modified: changelog

---
 changelog.md | 114 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 114 insertions(+)

diff --git a/changelog.md b/changelog.md
index 5e82ed0c5..c1329fce9 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,119 @@
 # Changelog

+## Unreleased
+
+### Added
+- Relation implementation in `doc.spans["