diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index d182c0ae0c..8ea2de096a 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -31,9 +31,9 @@ jobs: - name: Install dependencies run: | - pip install '.[docs]' + pip install . --group docs # uv venv -# uv pip install '.[docs]' +# uv pip install . --group docs - name: Set up Git run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 69f7f9a3d0..80920977c2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -88,7 +88,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install '.[docs]' + pip install . --group docs - name: Set up Git run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ff6662d46c..5eb647959e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,16 +61,16 @@ jobs: cache: 'pip' - name: Install dependencies - run: pip install -e ".[dev]" + run: pip install -e . --group dev if: matrix.python-version != '3.9' && matrix.python-version != '3.10' && matrix.python-version != '3.11' && matrix.python-version != '3.12' - name: Install dependencies - run: pip install -e ".[dev,setup]" + run: pip install -e . --group dev --group setup if: matrix.python-version == '3.9' - name: Install dependencies # skip ML tests for 3.10 and 3.11 - run: pip install -e ".[dev-no-ml]" + run: pip install -e . --group dev-no-ml if: matrix.python-version == '3.10' || matrix.python-version == '3.11' || matrix.python-version == '3.12' - name: Test with Pytest on Python ${{ matrix.python-version }} @@ -118,7 +118,7 @@ jobs: cache: 'pip' - name: Install dependencies - run: pip install -e ".[docs]" + run: pip install -e . --group docs - name: Set up Git run: | diff --git a/Makefile b/Makefile deleted file mode 100644 index b4ceada36c..0000000000 --- a/Makefile +++ /dev/null @@ -1,29 +0,0 @@ - -.ONESHELL: - SHELL:=/bin/bash - -.PHONY: create-env install documentation test - -default: - @echo "Call a specific subcommand: create-env,install,documentation,test" - -.venv: - python -m venv .venv - -create-env: .venv - -install : .venv - . .venv/bin/activate - pip install -r '.[dev,setup]'.txt - python scripts/conjugate_verbs.py - pip install -e . - pre-commit install - -documentation: .venv - . .venv/bin/activate - pip install -e '.[docs]' - mkdocs serve - -test: .venv - . .venv/bin/activate - python -m pytest diff --git a/changelog.md b/changelog.md index 1959892052..d2a28fb422 100644 --- a/changelog.md +++ b/changelog.md @@ -11,6 +11,10 @@ - New `eds.explode` pipe that splits one document into multiple documents, one per span yielded by its `span_getter` parameter, each new document containing exactly that single span. - New `Training a span classifier` tutorial, and reorganized deep-learning docs - `ScheduledOptimizer` now warns when a parameter selector does not match any parameter. +- New trainable `eds.relation_detector_ffn` component to detect relations between entities. These relations are stored in each entity: `head._.rel[relation_label] = [tail1, tail2, ...]`. 
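As an illustration of the API described in the entry above, here is a minimal sketch (not part of the patch itself) of reading a standoff/BRAT corpus and walking the `_.rel` dictionaries that the updated converter fills in. The path points at the `tests/training/dataset_2` fixture added by this PR; everything else only assumes the public `edsnlp.data.read_standoff` reader.

```python
import edsnlp

# Read a BRAT/standoff corpus; "R*" lines such as "R1 linked Arg1:T2 Arg2:T1"
# are attached by the converter to the head entity as head._.rel[label] -> set of tail spans.
docs = edsnlp.data.read_standoff("tests/training/dataset_2/")

for doc in docs:
    for head in doc.ents:
        for label, tails in head._.rel.items():
            for tail in tails:
                print(doc._.note_id, f"{head.text} -[{label}]-> {tail.text}")
```

At inference time, `eds.relation_detector_ffn` populates the same `_.rel` attribute, and mirrors the relation onto the tail span when the candidate getter is declared `symmetric`.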
+- Load "Status" annotator notes as `status` dict attribute +- New `attention` pooling mode in +- Support different poolers for span embedding and inter-span embeddings in `eds.relation_detector_ffn` ## Fixed diff --git a/contributing.md b/contributing.md index fcb6e46659..9c7872d20c 100644 --- a/contributing.md +++ b/contributing.md @@ -24,7 +24,7 @@ $ python -m venv venv $ source venv/bin/activate # Install the package with common, dev, setup dependencies in editable mode -$ pip install -e '.[dev,setup]' +$ pip install -e . --group dev --group setup # And build resources $ python scripts/conjugate_verbs.py ``` @@ -113,7 +113,7 @@ We use `MkDocs` for EDS-NLP's documentation. You can checkout the changes you ma ```console # Install the requirements -$ pip install -e '.[docs]' +$ pip install -e . --group docs ---> 100% color:green Installation successful diff --git a/docs/tutorials/training-ner.md b/docs/tutorials/training-ner.md index c3c151e823..40ea9745de 100644 --- a/docs/tutorials/training-ner.md +++ b/docs/tutorials/training-ner.md @@ -41,7 +41,7 @@ dependencies = [ "sentencepiece>=0.1.96" ] -[project.optional-dependencies] +[dependency-groups] dev = [ "dvc>=2.37.0; python_version >= '3.8'", "pandas>=1.1.0,<2.0.0; python_version < '3.8'", @@ -59,7 +59,7 @@ pip install uv # skip the next two lines if you do not want a venv uv venv .venv source .venv/bin/activate -uv pip install -e ".[dev]" -p $(uv python find) +uv pip install -e . --group dev -p $(uv python find) ``` ## Training the model diff --git a/docs/tutorials/training-span-classifier.md b/docs/tutorials/training-span-classifier.md index ce8cd61d29..b5d891ca9a 100644 --- a/docs/tutorials/training-span-classifier.md +++ b/docs/tutorials/training-span-classifier.md @@ -40,7 +40,7 @@ dependencies = [ "sentencepiece>=0.1.96" ] -[project.optional-dependencies] +[dependency-groups] dev = [ "dvc>=2.37.0; python_version >= '3.8'", "pandas>=1.4.0,<2.0.0; python_version >= '3.8'", @@ -56,7 +56,7 @@ We recommend using a virtual environment and [uv](https://docs.astral.sh/uv/): pip install uv uv venv .venv source .venv/bin/activate -uv pip install -e ".[dev]" +uv pip install -e . --group dev ``` ## Creating the dataset diff --git a/docs/tutorials/tuning.md b/docs/tutorials/tuning.md index 16d09b3d17..e9fe39a9a5 100644 --- a/docs/tutorials/tuning.md +++ b/docs/tutorials/tuning.md @@ -43,7 +43,7 @@ dependencies = [ "configobj>=5.0.9", ] -[project.optional-dependencies] +[dependency-groups] dev = [ "dvc>=2.37.0; python_version >= '3.8'", "pandas>=1.1.0,<2.0.0; python_version < '3.8'", @@ -61,7 +61,7 @@ pip install uv # skip the next two lines if you do not want a venv uv venv .venv source .venv/bin/activate -uv pip install -e ".[dev]" -p $(uv python find) +uv pip install -e . --group dev -p $(uv python find) ``` ## 2. 
Tuning a model diff --git a/edsnlp/data/converters.py b/edsnlp/data/converters.py index 5ce9343826..e0927bdc5d 100644 --- a/edsnlp/data/converters.py +++ b/edsnlp/data/converters.py @@ -243,76 +243,101 @@ def __init__( def __call__(self, obj, tokenizer=None): # tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer - tok = tokenizer or self.tokenizer or get_current_tokenizer() - doc = tok(obj["text"] or "") - doc._.note_id = obj.get("doc_id", obj.get(FILENAME)) - - spans = [] - - for dst in ( - *(() if self.span_attributes is None else self.span_attributes.values()), - *self.default_attributes, - ): - if not Span.has_extension(dst): - Span.set_extension(dst, default=None) - - for ent in obj.get("entities") or (): - fragments = ( - [ - { - "begin": min(f["begin"] for f in ent["fragments"]), - "end": max(f["end"] for f in ent["fragments"]), - } - ] - if not self.split_fragments - else ent["fragments"] - ) - for fragment in fragments: - span = doc.char_span( - fragment["begin"], - fragment["end"], - label=ent["label"], - alignment_mode="expand", - ) - attributes = ( - {a["label"]: a["value"] for a in ent["attributes"]} - if isinstance(ent["attributes"], list) - else ent["attributes"] + note_id = obj.get("doc_id", obj.get(FILENAME)) + try: + tok = tokenizer or self.tokenizer or get_current_tokenizer() + doc = tok(obj["text"] or "") + doc._.note_id = note_id + + entities = {} + spans = [] + + for dst in ( + *( + () + if self.span_attributes is None + else self.span_attributes.values() + ), + *self.default_attributes, + ): + if not Span.has_extension(dst): + Span.set_extension(dst, default=None) + + for ent in obj.get("entities") or (): + fragments = ( + [ + { + "begin": min(f["begin"] for f in ent["fragments"]), + "end": max(f["end"] for f in ent["fragments"]), + } + ] + if not self.split_fragments + else ent["fragments"] ) - if self.notes_as_span_attribute and ent["notes"]: - ent["attributes"][self.notes_as_span_attribute] = "|".join( - note["value"] for note in ent["notes"] + for fragment in fragments: + span = doc.char_span( + fragment["begin"], + fragment["end"], + label=ent["label"], + alignment_mode="expand", ) - for label, value in attributes.items(): - new_name = ( - self.span_attributes.get(label, None) - if self.span_attributes is not None - else label + attributes = ( + {} + if "attributes" not in ent + else {a["label"]: a["value"] for a in ent["attributes"]} + if isinstance(ent["attributes"], list) + else ent["attributes"] ) - if self.span_attributes is None and not Span.has_extension( - new_name - ): - Span.set_extension(new_name, default=None) - - if new_name: - value = True if value is None else value - if not self.keep_raw_attribute_values: - value = ( - True - if value in ("True", "true") - else False - if value in ("False", "false") - else value - ) - span._.set(new_name, value) - - spans.append(span) + if self.notes_as_span_attribute and ent["notes"]: + ent["attributes"][self.notes_as_span_attribute] = "|".join( + note["value"] for note in ent["notes"] + ) + for label, value in attributes.items(): + new_name = ( + self.span_attributes.get(label, None) + if self.span_attributes is not None + else label + ) + if self.span_attributes is None and not Span.has_extension( + new_name + ): + Span.set_extension(new_name, default=None) + + if new_name: + value = True if value is None else value + if not self.keep_raw_attribute_values: + value = ( + True + if value in ("True", "true") + else False + if value in ("False", "false") + else value + ) + 
span._.set(new_name, value) + + entities.setdefault(ent["entity_id"], []).append(span) + spans.append(span) + + set_spans(doc, spans, span_setter=self.span_setter) + for attr, value in self.default_attributes.items(): + for span in spans: + if span._.get(attr) is None: + span._.set(attr, value) + + for relation in obj.get("relations", []): + relation_label = ( + relation["relation_label"] + if "relation_label" in relation + else relation["label"] + ) + from_entity_id = relation["from_entity_id"] + to_entity_id = relation["to_entity_id"] - set_spans(doc, spans, span_setter=self.span_setter) - for attr, value in self.default_attributes.items(): - for span in spans: - if span._.get(attr) is None: - span._.set(attr, value) + for head in entities.get(from_entity_id, ()): + for tail in entities.get(to_entity_id, ()): + head._.rel.setdefault(relation_label, set()).add(tail) + except Exception: + raise ValueError(f"Error when processing {note_id}") return doc diff --git a/edsnlp/data/standoff.py b/edsnlp/data/standoff.py index bcecbf4bf5..4bfe0d71b1 100644 --- a/edsnlp/data/standoff.py +++ b/edsnlp/data/standoff.py @@ -32,6 +32,7 @@ REGEX_ATTRIBUTE = re.compile(r"^([AM]\d+)\t(.+?) ([TE]\d+)(?: (.+))?$") REGEX_EVENT = re.compile(r"^(E\d+)\t(.+)$") REGEX_EVENT_PART = re.compile(r"(\S+):([TE]\d+)") +REGEX_STATUS = re.compile(r"^(#\d+)\tStatus ([^\t]+)\t(.*)$") class BratParsingError(ValueError): @@ -71,6 +72,7 @@ def parse_standoff_file( entities = {} relations = [] events = {} + doc = {} with fs.open(txt_path, "r", encoding="utf-8") as f: text = f.read() @@ -178,6 +180,11 @@ def parse_standoff_file( "arguments": arguments, } elif line.startswith("#"): + match = REGEX_STATUS.match(line) + if match: + comment = match.group(3) + doc["status"] = comment + continue match = REGEX_NOTE.match(line) if match is None: raise BratParsingError(ann_file, line) @@ -201,6 +208,7 @@ def parse_standoff_file( "entities": list(entities.values()), "relations": relations, "events": list(events.values()), + **doc, } @@ -260,19 +268,19 @@ def dump_standoff_file( ) attribute_idx += 1 - # fmt: off - # if "relations" in doc: - # for i, relation in enumerate(doc["relations"]): - # entity_from = entities_ids[relation["from_entity_id"]] - # entity_to = entities_ids[relation["to_entity_id"]] - # print( - # "R{}\t{} Arg1:{} Arg2:{}\t".format( - # i + 1, str(relation["label"]), entity_from, - # entity_to - # ), - # file=f, - # ) - # fmt: on + # fmt: off + if "relations" in doc: + for i, relation in enumerate(doc["relations"]): + entity_from = entities_ids[relation["from_entity_id"]] + entity_to = entities_ids[relation["to_entity_id"]] + print( + "R{}\t{} Arg1:{} Arg2:{}\t".format( + i + 1, str(relation["label"]), entity_from, + entity_to + ), + file=f, + ) + # fmt: on class StandoffReader(FileBasedReader): diff --git a/edsnlp/extensions.py b/edsnlp/extensions.py index 7127afe5b2..be8c871116 100644 --- a/edsnlp/extensions.py +++ b/edsnlp/extensions.py @@ -2,7 +2,7 @@ from datetime import date, datetime from dateutil.parser import parse as parse_date -from spacy.tokens import Doc +from spacy.tokens import Doc, Span if not Doc.has_extension("note_id"): Doc.set_extension("note_id", default=None) @@ -43,3 +43,6 @@ def get_note_datetime(doc): if not Doc.has_extension("birth_datetime"): Doc.set_extension("birth_datetime", default=None) + +if not Span.has_extension("rel"): + Span.set_extension("rel", default={}) diff --git a/edsnlp/metrics/relations.py b/edsnlp/metrics/relations.py new file mode 100644 index 0000000000..1dc2df8eef --- 
/dev/null +++ b/edsnlp/metrics/relations.py @@ -0,0 +1,140 @@ +from collections import defaultdict +from itertools import product +from typing import Any, Optional + +from edsnlp import registry +from edsnlp.metrics import Examples, make_examples, prf +from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans +from edsnlp.utils.typing import AsList + + +def relations_scorer( + examples: Examples, + candidate_getter: AsList[RelationCandidateGetter], + micro_key: str = "micro", + filter_expr: Optional[str] = None, +): + """ + Scores the attributes predictions between a list of gold and predicted spans. + + Parameters + ---------- + examples : Examples + The examples to score, either a tuple of (golds, preds) or a list of + spacy.training.Example objects + candidate_getter : AsList[RelationCandidateGetter] + The candidate getters to use to extract the possible relations from the + documents. Each candidate getter should be a dictionary with the keys + "head", "tail", and "labels". The "head" and "tail" keys should be + SpanGetterArg objects, and the "labels" key should be a list of strings + for these head-tail pairs. + micro_key : str + The key to use to store the micro-averaged results for spans of all types + filter_expr : Optional[str] + The filter expression to use to filter the documents + + Returns + ------- + Dict[str, float] + """ + examples = make_examples(examples) + if filter_expr is not None: + filter_fn = eval(f"lambda doc: {filter_expr}") + examples = [eg for eg in examples if filter_fn(eg.reference)] + # annotations: {label -> preds, golds, pred_with_probs} + annotations = defaultdict(lambda: (set(), set(), dict())) + annotations[micro_key] = (set(), set(), dict()) + total_pred_count = 0 + total_gold_count = 0 + + for candidate in candidate_getter: + head_getter = candidate["head"] + tail_getter = candidate["tail"] + labels = candidate["labels"] + symmetric = candidate.get("symmetric") or False + label_filter = candidate.get("label_filter") + for eg_idx, eg in enumerate(examples): + pred_heads = [ + ((h.start, h.end, h.label_), h) + for h in get_spans(eg.predicted, head_getter) + ] + pred_tails = [ + ((t.start, t.end, t.label_), t) + for t in get_spans(eg.predicted, tail_getter) + ] + for (h_key, head), (t_key, tail) in product(pred_heads, pred_tails): + if label_filter is not None and ( + head.label_ not in label_filter or tail.label_ not in label_filter + ): + continue + total_pred_count += 1 + for label in labels: + if ( + tail in head._.rel.get(label, ()) + or symmetric + and head in tail._.rel.get(label, ()) + ): + if symmetric and h_key > t_key: + h_key, t_key = t_key, h_key + annotations[label][0].add((eg_idx, h_key, t_key, label)) + annotations[micro_key][0].add((eg_idx, h_key, t_key, label)) + + gold_heads = [ + ((h.start, h.end, h.label_), h) + for h in get_spans(eg.reference, head_getter) + ] + gold_tails = [ + ((t.start, t.end, t.label_), t) + for t in get_spans(eg.reference, tail_getter) + ] + for (h_key, head), (t_key, tail) in product(gold_heads, gold_tails): + total_gold_count += 1 + for label in labels: + if ( + tail in head._.rel.get(label, ()) + or symmetric + and head in tail._.rel.get(label, ()) + ): + if symmetric and h_key > t_key: + h_key, t_key = t_key, h_key + annotations[label][1].add((eg_idx, h_key, t_key, label)) + annotations[micro_key][1].add((eg_idx, h_key, t_key, label)) + + if total_pred_count != total_gold_count: + raise ValueError( + f"Number of predicted and gold candidate pairs differ: {total_pred_count} " + f"!= 
{total_gold_count}. Make sure that you are running your span " + "attribute classification pipe on the gold annotations, and not spans " + "predicted by another NER pipe in your model." + ) + + return { + name: { + **prf(pred, gold), + # "ap": average_precision(pred_with_prob, gold), + } + for name, (pred, gold, pred_with_prob) in annotations.items() + } + + +@registry.metrics.register("eds.relations") +class RelationsMetric: + def __init__( + self, + candidate_getter: AsList[RelationCandidateGetter], + micro_key: str = "micro", + filter_expr: Optional[str] = None, + ): + self.candidate_getter = candidate_getter + self.micro_key = micro_key + self.filter_expr = filter_expr + + __init__.__doc__ = relations_scorer.__doc__ + + def __call__(self, *examples: Any): + return relations_scorer( + examples, + candidate_getter=self.candidate_getter, + micro_key=self.micro_key, + filter_expr=self.filter_expr, + ) diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index aea3f0f088..f4aa79014c 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -75,6 +75,7 @@ from .qualifiers.reported_speech.factory import create_component as reported_speech from .qualifiers.reported_speech.factory import create_component as rspeech from .trainable.ner_crf.factory import create_component as ner_crf + from .trainable.relation_detector_ffn.factory import create_component as relation_detector_ffn from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser from .trainable.extractive_qa.factory import create_component as extractive_qa from .trainable.span_classifier.factory import create_component as span_classifier diff --git a/edsnlp/pipes/base.py b/edsnlp/pipes/base.py index c66ff40b54..0883c1180d 100644 --- a/edsnlp/pipes/base.py +++ b/edsnlp/pipes/base.py @@ -14,6 +14,7 @@ from edsnlp.core import PipelineProtocol from edsnlp.core.registries import DraftPipe from edsnlp.utils.span_getters import ( + RelationCandidateGetter, SpanGetter, # noqa: F401 SpanGetterArg, # noqa: F401 SpanSetter, @@ -23,6 +24,7 @@ validate_span_getter, # noqa: F401 validate_span_setter, ) +from edsnlp.utils.typing import AsList def value_getter(span: Span): @@ -188,3 +190,37 @@ def qualifiers(self): # pragma: no cover @qualifiers.setter def qualifiers(self, value): # pragma: no cover self.attributes = value + + +class BaseRelationDetectorComponent(BaseComponent, abc.ABC): + def __init__( + self, + nlp: PipelineProtocol = None, + name: str = None, + *args, + candidate_getter: AsList[RelationCandidateGetter], + **kwargs, + ): + super().__init__(nlp, name, *args, **kwargs) + self.candidate_getter = [ + { + "head": validate_span_getter(candidate["head"]), + "tail": validate_span_getter(candidate["tail"]), + "labels": candidate["labels"], + "label_filter": { + head: set(tail_labels) + for head, tail_labels in candidate["label_filter"].items() + } + if candidate.get("label_filter") + else None, + "symmetric": candidate.get("symmetric") or False, + } + for candidate in candidate_getter + ] + self.labels = sorted( + { + label + for candidate in self.candidate_getter + for label in candidate["labels"] + } + ) diff --git a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py index 5e58cd9bb6..7cb4abecfe 100644 --- a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py +++ b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py @@ -24,7 +24,10 @@ "embedding": BatchInput, "begins": ft.FoldedTensor, "ends": 
ft.FoldedTensor, - "sequence_idx": torch.Tensor, + "word_to_span_idx": torch.Tensor, + "span_to_ctx_idx": torch.Tensor, + "flat_indices": torch.Tensor, + "offsets": torch.Tensor, "stats": TypedDict("SpanPoolerBatchStats", {"spans": int}), }, ) @@ -35,8 +38,16 @@ Begin offsets of the spans ends: torch.LongTensor End offsets of the spans -sequence_idx: torch.LongTensor - Sequence (cf Embedding spans) index of the spans +word_to_span_idx: torch.LongTensor + Span index of each token in the flattened span tokens +span_to_ctx_idx: torch.LongTensor + Sequence/context (cf Embedding spans) index of the spans +flat_indices: torch.LongTensor + Indices of the tokens in the flattened span tokens +offsets: torch.LongTensor + Offsets of the spans in the flattened span tokens +stats: Dict[str, int] + Statistics about the batch, e.g. number of spans """ SpanPoolerBatchOutput = TypedDict( @@ -61,8 +72,14 @@ class SpanPooler(SpanEmbeddingComponent, BaseComponent): Name of the component embedding : WordEmbeddingComponent The word embedding component - pooling_mode: Literal["max", "sum", "mean"] - How word embeddings are aggregated into a single embedding per span. + pooling_mode: Literal["max", "sum", "mean", "attention"] + How word embeddings are aggregated into a single embedding per span: + + - "max": max pooling + - "sum": sum pooling + - "mean": mean pooling + - "attention": attention pooling, where attention scores are computed using a + linear layer followed by a softmax over the tokens in the span. hidden_size : Optional[int] The size of the hidden layer. If None, no projection is done and the output of the span pooler is used directly. @@ -74,7 +91,9 @@ def __init__( name: str = "span_pooler", *, embedding: WordEmbeddingComponent, - pooling_mode: Literal["max", "sum", "mean"] = "mean", + pooling_mode: Literal["max", "sum", "mean", "attention"] = "mean", + activation: Optional[str] = None, + norm: Optional[str] = None, hidden_size: Optional[int] = None, span_getter: Any = None, ): @@ -99,11 +118,35 @@ def __init__( self.pooling_mode = pooling_mode self.span_getter = span_getter self.embedding = embedding - self.projector = ( - torch.nn.Linear(self.embedding.output_size, hidden_size) - if hidden_size is not None - else torch.nn.Identity() - ) + self.activation = activation + self.projector = torch.nn.Sequential() + if hidden_size is not None: + self.projector.append( + torch.nn.Linear(self.embedding.output_size, hidden_size) + ) + if activation is not None: + self.projector.append( + { + "relu": torch.nn.ReLU, + "gelu": torch.nn.GELU, + "silu": torch.nn.SiLU, + }[activation]() + ) + if norm is not None: + self.projector.append( + { + "layernorm": torch.nn.LayerNorm, + "batchnorm": torch.nn.BatchNorm1d, + }[norm]( + hidden_size + if hidden_size is not None + else self.embedding.output_size + ) + ) + if self.pooling_mode in {"attention"}: + self.attention_scorer = torch.nn.Linear( + self.embedding.output_size, 1, bias=False + ) def feed_forward(self, span_embeds: torch.Tensor) -> torch.Tensor: return self.projector(span_embeds) @@ -119,7 +162,7 @@ def preprocess( ) -> Dict[str, Any]: contexts = contexts if contexts is not None else [doc[:]] - sequence_idx = [] + context_indices = [] begins = [] ends = [] @@ -140,44 +183,96 @@ def preprocess( f"span: {[s.text for s in ctx]}" ) start = ctx[0].start - sequence_idx.append(contexts_to_idx[ctx[0]]) + context_indices.append(contexts_to_idx[ctx[0]]) begins.append(span.start - start) ends.append(span.end - start) return { "begins": begins, "ends": ends, - 
"sequence_idx": sequence_idx, + "span_to_ctx_idx": context_indices, "num_sequences": len(contexts), "embedding": self.embedding.preprocess(doc, contexts=contexts, **kwargs), "stats": {"spans": len(begins)}, } def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanPoolerBatchInput: - sequence_idx = [] - offset = 0 - for indices, seq_length in zip(batch["sequence_idx"], batch["num_sequences"]): - sequence_idx.extend([offset + idx for idx in indices]) - offset += seq_length + embedding_batch = self.embedding.collate(batch["embedding"]) + n_words = embedding_batch["stats"]["words"] + span_to_ctx_idx = [] + word_to_span_idx = [] + offset_ctx = 0 + offset_span = 0 + flat_indices = [] + offsets = [0] + for indices, num_sample_contexts, begins, ends in zip( + batch["span_to_ctx_idx"], + batch["num_sequences"], + batch["begins"], + batch["ends"], + ): + span_to_ctx_idx.extend([offset_ctx + idx for idx in indices]) + offset_ctx += num_sample_contexts + for b, e, ctx_idx in zip(begins, ends, indices): + offset_word = n_words * ctx_idx + word_to_span_idx.extend([offset_span] * (e - b)) + flat_indices.extend(range(offset_word + b, offset_word + e)) + offsets.append(len(flat_indices)) + offset_span += 1 + offsets = offsets[:-1] + begins = ft.as_folded_tensor( + batch["begins"], + data_dims=("span",), + full_names=("sample", "span"), + dtype=torch.long, + ) + ends = ft.as_folded_tensor( + batch["ends"], + data_dims=("span",), + full_names=("sample", "span"), + dtype=torch.long, + ) collated: SpanPoolerBatchInput = { - "embedding": self.embedding.collate(batch["embedding"]), - "begins": ft.as_folded_tensor( - batch["begins"], - data_dims=("span",), - full_names=("sample", "span"), - dtype=torch.long, - ), - "ends": ft.as_folded_tensor( - batch["ends"], - data_dims=("span",), - full_names=("sample", "span"), - dtype=torch.long, - ), - "sequence_idx": torch.as_tensor(sequence_idx), + "embedding": embedding_batch, + "begins": begins, + "ends": ends, + "flat_indices": torch.as_tensor(flat_indices), # (num_span_tokens,) + "offsets": torch.as_tensor(offsets), # (num_spans,) + "word_to_span_idx": torch.as_tensor(word_to_span_idx), # (num_span_tokens,) + "span_to_ctx_idx": torch.as_tensor(span_to_ctx_idx), # (num_spans,) "stats": {"spans": sum(batch["stats"]["spans"])}, } return collated + def _pool_spans(self, flat_embeds, word_to_span_idx, offsets): + dev = offsets.device + dim = flat_embeds.size(-1) + n_spans = len(offsets) + + if self.pooling_mode == "attention": + weights = self.attention_scorer(flat_embeds) + # compute max for softmax stability + max_weights = torch.full((n_spans, 1), float("-inf"), device=dev) + max_weights.index_reduce_(0, word_to_span_idx, weights, reduce="amax") + # softmax numerator + exp_scores = torch.exp(weights - max_weights[word_to_span_idx]) + # softmax denominator + denom = torch.zeros((n_spans, 1), device=dev) + denom.index_add_(0, word_to_span_idx, exp_scores) + # softmax output = embeds * weight num / weight denom + weighted_embeds = flat_embeds * exp_scores / denom[word_to_span_idx] + span_embeds = torch.zeros((n_spans, dim), device=dev) + span_embeds.index_add_(0, word_to_span_idx, weighted_embeds) + else: + span_embeds = torch.nn.functional.embedding_bag( # type: ignore + input=torch.arange(len(flat_embeds), device=dev), + weight=flat_embeds, + offsets=offsets, + mode=self.pooling_mode, + ) + span_embeds = self.feed_forward(span_embeds) + return span_embeds + # noinspection SpellCheckingInspection def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: 
""" @@ -196,36 +291,25 @@ def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: ------- BatchOutput """ - device = next(self.parameters()).device - if len(batch["begins"]) == 0: - span_embeds = torch.empty(0, self.output_size, device=device) + n_spans = len(batch["begins"]) + word_to_span_idx = batch["word_to_span_idx"] + offsets = batch["offsets"] + flat_indices = batch["flat_indices"] + + if n_spans == 0: + span_embeds = torch.empty(0, self.output_size, device=offsets.dev) return { "embeddings": batch["begins"].with_data(span_embeds), } embeds = self.embedding(batch["embedding"])["embeddings"] - _, n_words, dim = embeds.shape - device = embeds.device - - flat_begins = n_words * batch["sequence_idx"] + batch["begins"].as_tensor() - flat_ends = n_words * batch["sequence_idx"] + batch["ends"].as_tensor() - flat_embeds = embeds.view(-1, dim) - flat_indices = torch.cat( - [ - torch.arange(b, e, device=device) - for b, e in zip(flat_begins.cpu().tolist(), flat_ends.cpu().tolist()) - ] - ).to(device) - offsets = (flat_ends - flat_begins).cumsum(0).roll(1) - offsets[0] = 0 - span_embeds = torch.nn.functional.embedding_bag( # type: ignore - input=flat_indices, - weight=flat_embeds, - offsets=offsets, - mode=self.pooling_mode, + embeds = embeds.refold(["context", "word"]) + flat_embeds = embeds.view(-1, embeds.size(-1))[flat_indices] + span_embeds = self._pool_spans( + flat_embeds, + word_to_span_idx, + offsets, ) - span_embeds = self.feed_forward(span_embeds) - return { "embeddings": batch["begins"].with_data(span_embeds), } diff --git a/edsnlp/pipes/trainable/relation_detector_ffn/__init__.py b/edsnlp/pipes/trainable/relation_detector_ffn/__init__.py new file mode 100644 index 0000000000..549d2fc779 --- /dev/null +++ b/edsnlp/pipes/trainable/relation_detector_ffn/__init__.py @@ -0,0 +1 @@ +from .factory import create_component diff --git a/edsnlp/pipes/trainable/relation_detector_ffn/factory.py b/edsnlp/pipes/trainable/relation_detector_ffn/factory.py new file mode 100644 index 0000000000..066f877043 --- /dev/null +++ b/edsnlp/pipes/trainable/relation_detector_ffn/factory.py @@ -0,0 +1,9 @@ +from edsnlp import registry + +from .relation_detector_ffn import RelationDetectorFFN + +create_component = registry.factory.register( + "eds.relation_detector_ffn", + assigns=[], + deprecated=[], +)(RelationDetectorFFN) diff --git a/edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py b/edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py new file mode 100644 index 0000000000..db94a1b2ec --- /dev/null +++ b/edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import logging +import warnings +from collections import defaultdict +from itertools import product +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Set, +) + +import torch +import torch.nn.functional as F +from spacy.tokens import Doc, Span +from typing_extensions import TypedDict + +from edsnlp.core import PipelineProtocol +from edsnlp.core.torch_component import BatchInput, BatchOutput, TorchComponent +from edsnlp.pipes.base import BaseRelationDetectorComponent +from edsnlp.pipes.trainable.embeddings.typing import ( + SpanEmbeddingComponent, + WordEmbeddingComponent, +) +from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans +from edsnlp.utils.typing import AsList + + +def make_ranges(starts, ends): + """ + Efficient computation and concat of ranges from starts and ends. 
+ + Examples + -------- + ```{ .python .no-check } + + starts = torch.tensor([0, 3, 6]) + ends = torch.tensor([2, 8, 8]) + make_ranges(starts, ends) + # <---> <-----------> <---> + # tensor([0, 1, 3, 4, 5, 6, 7, 6, 7]) + ``` + + Parameters + ---------- + starts: torch.Tensor + ends: torch.Tensor + + Returns + ------- + torch.Tensor + """ + assert starts.shape == ends.shape + if 0 in ends.shape: + return ends + sizes = ends - starts + mask = sizes > 0 + offsets = sizes.cumsum(0) + offsets = offsets.roll(1) + res = torch.ones(offsets[0], dtype=torch.long, device=starts.device) + offsets[0] = 0 + masked_offsets = offsets[mask] + starts = starts[mask] + ends = ends[mask] + res[masked_offsets] = starts + res[masked_offsets[1:]] -= ends[:-1] - 1 + return res.cumsum(0), offsets + + +logger = logging.getLogger(__name__) + +FrameBatchInput = TypedDict( + "FrameBatchInput", + { + "span_embedding": BatchInput, + "word_embedding": BatchInput, + "rel_head_idx": torch.Tensor, + "rel_tail_idx": torch.Tensor, + "rel_doc_idx": torch.Tensor, + "rel_labels": torch.Tensor, + }, +) +""" +span_embedding: torch.FloatTensor + Token embeddings to predict the tags from +""" + + +class MLP(torch.nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, dropout_p: float = 0.0 + ): + super().__init__() + self.hidden = torch.nn.Linear(input_dim, hidden_dim) + self.output = torch.nn.Linear(hidden_dim, output_dim) + self.dropout = torch.nn.Dropout(dropout_p) + + def forward(self, x): + x = self.dropout(x) + x = self.hidden(x) + x = F.gelu(x) + x = self.output(x) + return x + + +class RelationDetectorFFN( + TorchComponent[BatchOutput, FrameBatchInput], + BaseRelationDetectorComponent, +): + def __init__( + self, + nlp: Optional[PipelineProtocol] = None, + name: str = "relation_detector_ffn", + *, + span_embedding: SpanEmbeddingComponent, + inter_span_embedding: Optional[SpanEmbeddingComponent] = None, + word_embedding: WordEmbeddingComponent, + candidate_getter: AsList[RelationCandidateGetter], + hidden_size: int = 128, + dropout_p: float = 0.0, + use_inter_words: bool = True, + ): + super().__init__( + nlp=nlp, + name=name, + candidate_getter=candidate_getter, + ) + self.span_embedding = span_embedding + self.inter_span_embedding = inter_span_embedding or span_embedding + self.word_embedding = word_embedding + self.use_inter_words = use_inter_words + + embed_size = self.span_embedding.output_size * 2 + ( + self.word_embedding.output_size if use_inter_words else 0 + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + # self.head_projection = torch.nn.Linear(hidden_size, hidden_size) + # self.tail_projection = torch.nn.Linear(hidden_size, hidden_size) + self.mlp = MLP(embed_size, hidden_size, hidden_size, dropout_p) + self.classifier = torch.nn.Linear(hidden_size, len(self.labels)) + + @property + def span_getter(self): + return self.embedding.span_getter + + def to_disk(self, path, *, exclude=set()): + repr_id = object.__repr__(self) + if repr_id in exclude: + return + return super().to_disk(path, exclude=exclude) + + def from_disk(self, path, exclude=tuple()): + repr_id = object.__repr__(self) + if repr_id in exclude: + return + self.set_extensions() + super().from_disk(path, exclude=exclude) + + def set_extensions(self): + super().set_extensions() + if not Span.has_extension("rel"): + Span.set_extension("rel", default={}) + if not Span.has_extension("scope"): + Span.set_extension("scope", default=None) + + def post_init(self, gold_data: Iterable[Doc], exclude: 
Set[str]):
+        super().post_init(gold_data, exclude=exclude)
+
+    def preprocess(self, doc: Doc, supervised: bool = False) -> Dict[str, Any]:
+        rel_head_idx = []
+        rel_tail_idx = []
+        rel_labels = []
+        rel_getter_indices = []
+
+        all_spans = defaultdict(lambda: len(all_spans))
+
+        for getter_idx, getter in enumerate(self.candidate_getter):
+            head_spans = list(get_spans(doc, getter["head"]))
+            tail_spans = list(get_spans(doc, getter["tail"]))
+            lab_filter = getter.get("label_filter")
+            for head, tail in product(head_spans, tail_spans):
+                if (
+                    lab_filter
+                    and head.label_ in lab_filter
+                    and tail.label_ not in lab_filter[head.label_]
+                ):
+                    continue
+                rel_head_idx.append(all_spans[head])
+                rel_tail_idx.append(all_spans[tail])
+                rel_getter_indices.append(getter_idx)
+                if supervised:
+                    rel_labels.append(
+                        [
+                            (
+                                tail in head._.rel.get(lab, ())
+                                or (
+                                    getter["symmetric"]
+                                    and head in tail._.rel.get(lab, ())
+                                )
+                            )
+                            for lab in self.labels
+                        ]
+                    )
+
+        result = {
+            "num_spans": len(all_spans),
+            "rel_heads": rel_head_idx,
+            "rel_tails": rel_tail_idx,
+            "word_embedding": self.word_embedding.preprocess(doc, contexts=None),
+            "span_embedding": self.span_embedding.preprocess(
+                doc,
+                spans=list(all_spans),
+                contexts=None,
+            ),
+            "$spans": list(all_spans.keys()),
+            "$getter": rel_getter_indices,
+            "stats": {
+                "relation_candidates": len(rel_head_idx),
+            },
+        }
+        if supervised:
+            result["rel_labels"] = rel_labels
+
+        return result
+
+    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
+        return self.preprocess(doc, supervised=True)
+
+    def collate(self, batch: Dict[str, Any]) -> FrameBatchInput:
+        rel_heads = []
+        rel_tails = []
+        rel_doc_idx = []
+        offset = 0
+        for doc_idx, feats in enumerate(
+            zip(
+                batch["rel_heads"],
+                batch["rel_tails"],
+                batch["num_spans"],
+            )
+        ):
+            doc_rel_heads, doc_rel_tails, doc_num_spans = feats
+            rel_heads.extend([x + offset for x in doc_rel_heads])
+            rel_tails.extend([x + offset for x in doc_rel_tails])
+            rel_doc_idx.extend([doc_idx] * len(doc_rel_heads))
+            offset += batch["num_spans"][doc_idx]
+
+        collated: FrameBatchInput = {
+            "rel_head_idx": torch.as_tensor(rel_heads, dtype=torch.long),
+            "rel_tail_idx": torch.as_tensor(rel_tails, dtype=torch.long),
+            "rel_doc_idx": torch.as_tensor(rel_doc_idx, dtype=torch.long),
+            "span_embedding": self.span_embedding.collate(batch["span_embedding"]),
+            "word_embedding": self.word_embedding.collate(batch["word_embedding"]),
+            "stats": {"relation_candidates": len(rel_heads)},
+        }
+
+        if "rel_labels" in batch:
+            collated["rel_labels"] = torch.as_tensor(
+                [labs for doc_labels in batch["rel_labels"] for labs in doc_labels]
+            ).view(-1, self.classifier.out_features)
+        return collated
+
+    def compute_inter_span_embeds(self, word_embeds, begins, ends, head_idx, tail_idx):
+        _, n_words, dim = word_embeds.shape
+        if 0 in begins.shape or 0 in head_idx.shape:
+            return torch.zeros(
+                0, dim, dtype=word_embeds.dtype, device=word_embeds.device
+            )
+
+        flat_begins = torch.minimum(ends[head_idx], ends[tail_idx])
+        flat_ends = torch.maximum(begins[head_idx], begins[tail_idx])
+        flat_begins, flat_ends = (
+            torch.minimum(flat_begins, flat_ends),
+            torch.maximum(flat_begins, flat_ends),
+        )
+        lengths = flat_ends - flat_begins
+        word_to_span_idx = torch.arange(
+            len(head_idx), device=word_embeds.device
+        ).repeat_interleave(lengths)
+        flat_indices, flat_offsets = make_ranges(flat_begins, flat_ends)
+        flat_offsets[0] = 0
+        inter_span_embeds = self.inter_span_embedding._pool_spans(
+            flat_embeds=word_embeds.view(-1, dim)[flat_indices],
word_to_span_idx=word_to_span_idx, + offsets=flat_offsets, + ) + return inter_span_embeds + + # noinspection SpellCheckingInspection + def forward(self, batch: FrameBatchInput) -> BatchOutput: + """ + Apply the span classifier module to the document embeddings and given spans to: + - compute the loss + - and/or predict the labels of spans + + Parameters + ---------- + batch: SpanQualifierBatchInput + The input batch + + Returns + ------- + BatchOutput + """ + word_embeds = self.word_embedding(batch["word_embedding"])["embeddings"] + span_embeds = self.span_embedding(batch["span_embedding"])["embeddings"] + + n_words = word_embeds.size(-2) + spans = batch["span_embedding"] + flat_begins = n_words * spans["span_to_ctx_idx"] + spans["begins"].as_tensor() + flat_ends = n_words * spans["span_to_ctx_idx"] + spans["ends"].as_tensor() + if self.use_inter_words: + inter_span_embeds = self.compute_inter_span_embeds( + word_embeds=word_embeds, + begins=flat_begins, + ends=flat_ends, + head_idx=batch["rel_head_idx"], + tail_idx=batch["rel_tail_idx"], + ) + rel_embeds = torch.cat( + [ + span_embeds[batch["rel_head_idx"]], + inter_span_embeds, + span_embeds[batch["rel_tail_idx"]], + ], + dim=-1, + ) + else: + rel_embeds = torch.cat( + [ + span_embeds[batch["rel_head_idx"]], + span_embeds[batch["rel_tail_idx"]], + ], + dim=-1, + ) + rel_embeds = self.mlp(rel_embeds) + logits = self.classifier(rel_embeds) + + losses = pred = None + if "rel_labels" in batch: + losses = [] + target = batch["rel_labels"].float() + num_relation_candidates = batch["stats"]["relation_candidates"] + losses.append( + F.binary_cross_entropy_with_logits(logits, target, reduction="sum") + / num_relation_candidates + ) + else: + pred = logits > 0 + + return { + "loss": sum(losses) if losses is not None else None, + "pred": pred, + } + + def postprocess( + self, + docs: List[Doc], + results: BatchOutput, + inputs: List[Dict[str, Any]], + ): + """ + Extract predicted relations from forward's "pred" field (boolean tensor) + and annotated them on the head._.rel attribute (dictionary) + Parameters + ---------- + docs: Sequence[Doc] + List of documents to update + results: BatchOutput + Batch of predictions, as returned by the forward method + inputs: BatchInput + List of preprocessed features, as returned by the preprocess method + + Returns + ------- + """ + all_heads = [p["$spans"][idx] for p in inputs for idx in p["rel_heads"]] + all_tails = [p["$spans"][idx] for p in inputs for idx in p["rel_tails"]] + getter_indices = [idx for p in inputs for idx in p["$getter"]] + for pair_idx, label_idx in results["pred"].nonzero(as_tuple=False).tolist(): + head = all_heads[pair_idx] + tail = all_tails[pair_idx] + label = self.labels[label_idx] + head._.rel.setdefault(label, set()).add(tail) + if self.candidate_getter[getter_indices[pair_idx]]["symmetric"]: + tail._.rel.setdefault(label, set()).add(head) + return docs diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py index 2400aa2b12..42a57f31ae 100644 --- a/edsnlp/training/trainer.py +++ b/edsnlp/training/trainer.py @@ -35,6 +35,7 @@ from edsnlp.metrics.span_attribute import SpanAttributeMetric from edsnlp.pipes.base import ( BaseNERComponent, + BaseRelationDetectorComponent, BaseSpanAttributeClassifierComponent, ) from edsnlp.utils.batching import BatchSizeArg, stat_batchify @@ -50,6 +51,7 @@ from edsnlp.utils.typing import AsList from ..core.torch_component import TorchComponent +from ..metrics.relations import RelationsMetric from .optimizer import LinearSchedule, 
ScheduledOptimizer @@ -189,6 +191,35 @@ def __call__(self, nlp: Pipeline, docs: Iterable[Any]): for name, metric in span_attr_metrics.items(): scores[name] = metric(docs, qlf_preds) + # Relations + rel_pipes = [ + name + for name, pipe in nlp.pipeline + if isinstance(pipe, BaseRelationDetectorComponent) + ] + rel_metrics: Dict[str, RelationsMetric] = { # type: ignore + name: metrics.pop(name) + for name in list(metrics) + if isinstance(metrics[name], RelationsMetric) + } + if rel_pipes and rel_metrics: + clean_rel_docs = [d.copy() for d in tqdm(docs, desc="Copying docs")] + for doc in clean_rel_docs: + for name in rel_pipes: + pipe: BaseRelationDetectorComponent = nlp.get_pipe(name) # type: ignore + for candidate_getter in pipe.candidate_getter: + for span in ( + *get_spans(doc, candidate_getter["head"]), + *get_spans(doc, candidate_getter["tail"]), + ): + for label in pipe.labels: + if label in span._.rel: + span._.rel[label].clear() + with nlp.select_pipes(disable=ner_pipes): + rel_preds = list(nlp.pipe(tqdm(clean_rel_docs, desc="Predicting"))) + for name, scorer in rel_metrics.items(): + scores[name] = scorer(docs, rel_preds) + # Custom metrics for name, metric in metrics.items(): pred_docs = [d.copy() for d in tqdm(docs, desc="Copying docs")] diff --git a/edsnlp/utils/span_getters.py b/edsnlp/utils/span_getters.py index 74e325a769..778ccac42a 100644 --- a/edsnlp/utils/span_getters.py +++ b/edsnlp/utils/span_getters.py @@ -9,6 +9,7 @@ List, Optional, Sequence, + Set, Tuple, Union, ) @@ -16,6 +17,7 @@ import numpy as np from pydantic import NonNegativeInt from spacy.tokens import Doc, Span +from typing_extensions import NotRequired, TypedDict from edsnlp import registry from edsnlp.utils.filter import filter_spans @@ -49,6 +51,10 @@ def get_spans(doclike, span_getter, deduplicate=True): if isinstance(doclike, Doc): if k == "*": candidates = (s for grp in doclike.spans.values() for s in grp) + elif k == "ents": + candidates = doclike.ents + elif k == "doc": + candidates = (doclike[:],) else: candidates = doclike.spans.get(k, ()) if k != "ents" else doclike.ents else: @@ -60,10 +66,14 @@ def get_spans(doclike, span_getter, deduplicate=True): for s in grp if not (s.end < doclike.start or s.start > doclike.end) ) + elif k == "ents": + candidates = doclike.ents + elif k == "doc": + candidates = (doclike[:],) else: candidates = ( s - for s in (doc.spans.get(k, ()) if k != "ents" else doc.ents) + for s in (doc.spans.get(k, ())) if not (s.end < doclike.start or s.start > doclike.end) ) for span in candidates: @@ -86,8 +96,12 @@ def get_spans_with_group(doc, span_getter): candidates = ( (span, group) for group in doc.spans.values() for span in group ) + elif key == "ents": + candidates = ((span, key) for span in doc.ents) + elif key == "doc": + candidates = ((doc[:], "doc"),) else: - candidates = doc.spans.get(key, ()) if key != "ents" else doc.ents + candidates = doc.spans.get(key, ()) candidates = ((span, key) for span in candidates) if span_filter is True: yield from candidates @@ -548,3 +562,15 @@ def __call__(self, span): def __repr__(self): return " & ".join(repr(context) for context in self.contexts) + + +RelationCandidateGetter = TypedDict( + "RelationCandidateGetter", + { + "head": SpanGetterArg, + "tail": SpanGetterArg, + "labels": AsList[str], + "label_filter": NotRequired[Optional[Dict[str, Set[str]]]], + "symmetric": Optional[bool], + }, +) diff --git a/pyproject.toml b/pyproject.toml index 2c8ddc6520..c2e53ef11d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,18 @@ 
dependencies = [ "pydantic<2.0.0; python_version<'3.8'", "pydantic-core<2.0.0; python_version<'3.8'", ] -[project.optional-dependencies] +[optional-dependencies] +ml = [ + "rich-logger>=0.3.1", + "torch>=1.13.0; python_version>='3.9'", + "foldedtensor>=0.4.0", + "safetensors>=0.3.0; python_version>='3.8'", + "safetensors>=0.3.0,<0.5.0; python_version<'3.8'", + "transformers>=4.0.0", + "accelerate>=0.20.3", +] + +[dependency-groups] dev-no-ml = [ "pre-commit>=2.0.0; python_version<'3.8'", "pre-commit>=2.21.0; python_version>='3.8'", @@ -67,15 +78,6 @@ docs-no-ml = [ "mkdocs-eds @ git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'", "markdown-grid-tables==0.4.0; python_version>='3.9'", ] -ml = [ - "rich-logger>=0.3.1", - "torch>=1.13.0; python_version>='3.9'", - "foldedtensor>=0.4.0", - "safetensors>=0.3.0; python_version>='3.8'", - "safetensors>=0.3.0,<0.5.0; python_version<'3.8'", - "transformers>=4.0.0", - "accelerate>=0.20.3", -] docs = [ "edsnlp[docs-no-ml]", "edsnlp[ml]", @@ -94,6 +96,9 @@ setup = [ "mlconjug3<3.9.0", # bug https://github.com/Ars-Linguistica/mlconjug3/pull/506 "numpy<2", # mlconjug has scikit-learn dep which doesn't support for numpy 2 yet ] +ml = [ + "edsnlp[ml]", +] [project.urls] "Source Code" = "https://github.com/aphp/edsnlp" @@ -249,16 +254,17 @@ where = ["."] # edsnlp will look both in the above dict and in the one below. [project.entry-points."edsnlp_factories"] # Trainable -"eds.transformer" = "edsnlp.pipes.trainable.embeddings.transformer.factory:create_component" -"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" -"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" -"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" -"eds.extractive_qa" = "edsnlp.pipes.trainable.extractive_qa.factory:create_component" -"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" -"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" -"eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" -"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" -"eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" +"eds.transformer" = "edsnlp.pipes.trainable.embeddings.transformer.factory:create_component" +"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" +"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" +"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.extractive_qa" = "edsnlp.pipes.trainable.extractive_qa.factory:create_component" +"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" +"eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" +"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" +"eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" +"eds.relation_detector_ffn" = "edsnlp.pipes.trainable.relation_detector_ffn.factory:create_component" [project.entry-points."edsnlp_schedules"] "linear" = "edsnlp.training.optimizer:LinearSchedule" @@ -269,6 +275,7 @@ where = ["."] "eds.ner_overlap" = "edsnlp.metrics.ner:NerOverlapMetric" "eds.span_attribute" = 
"edsnlp.metrics.span_attribute:SpanAttributeMetric" "eds.dep_parsing" = "edsnlp.metrics.dep_parsing:DependencyParsingMetric" +"eds.relations" = "edsnlp.metrics.relations:RelationsMetric" # Deprecated "eds.ner_exact_metric" = "edsnlp.metrics.ner:NerExactMetric" diff --git a/tests/data/test_converters.py b/tests/data/test_converters.py index c53ca1ef73..9e5d42396e 100644 --- a/tests/data/test_converters.py +++ b/tests/data/test_converters.py @@ -83,6 +83,14 @@ def test_read_standoff_dict(blank_nlp): "label": "test", }, ], + "relations": [ + { + "relation_id": "R1", + "relation_label": "linked", + "from_entity_id": 1, + "to_entity_id": 0, + } + ], } doc = get_dict2doc_converter( "standoff", @@ -98,6 +106,7 @@ def test_read_standoff_dict(blank_nlp): assert doc.ents[0].text == "This" assert doc.ents[0]._.negation is True assert doc.ents[1]._.negation is False + assert doc.ents[1]._.rel["linked"] == {doc.ents[0]} def test_write_omop_dict(blank_nlp): diff --git a/tests/training/dataset_2/sample-1.ann b/tests/training/dataset_2/sample-1.ann new file mode 100644 index 0000000000..d5ee1745fa --- /dev/null +++ b/tests/training/dataset_2/sample-1.ann @@ -0,0 +1,6 @@ +#1000 Status #1000 CHECKED +T1 date 6 18 19 juin 1987 +T2 covid 52 57 covid +T3 date 69 84 12 octobre 1983 +T3 covid 103 108 covid +R1 linked Arg1:T2 Arg2:T1 diff --git a/tests/training/dataset_2/sample-1.txt b/tests/training/dataset_2/sample-1.txt new file mode 100644 index 0000000000..e332852e86 --- /dev/null +++ b/tests/training/dataset_2/sample-1.txt @@ -0,0 +1,2 @@ +CR du 19 juin 1987. La patiente a été diagnostiquée covid positif le 12 octobre 1983. +Autre occurrence covid mentionnée sans date précise. diff --git a/tests/training/dataset_2/sample-2.ann b/tests/training/dataset_2/sample-2.ann new file mode 100644 index 0000000000..bd4ab3c7e5 --- /dev/null +++ b/tests/training/dataset_2/sample-2.ann @@ -0,0 +1,4 @@ +T1 date 11 24 29 avril 2020 +T2 date 62 70 30 avril +T3 covid 101 106 covid +R1 linked Arg1:T3 Arg2:T1 diff --git a/tests/training/dataset_2/sample-2.txt b/tests/training/dataset_2/sample-2.txt new file mode 100644 index 0000000000..9d0c579b50 --- /dev/null +++ b/tests/training/dataset_2/sample-2.txt @@ -0,0 +1 @@ +On est le 29 avril 2020, et j'ai rendez vous à l'aéroport le 30 avril et je n'ai toujours pas eu le covid, je croise les doigts ! diff --git a/tests/training/ner_qlf_same_bert_config.yml b/tests/training/ner_qlf_same_bert_config.yml index 89544f5640..a57b2a5ad2 100644 --- a/tests/training/ner_qlf_same_bert_config.yml +++ b/tests/training/ner_qlf_same_bert_config.yml @@ -36,6 +36,7 @@ nlp: embedding: '@factory': eds.span_pooler + pooling_mode: attention embedding: # ${ nlp.components.ner.embedding } '@factory': eds.text_cnn diff --git a/tests/training/rel_config.cfg b/tests/training/rel_config.cfg new file mode 100644 index 0000000000..7a2dc2ea70 --- /dev/null +++ b/tests/training/rel_config.cfg @@ -0,0 +1,79 @@ +[nlp] +lang = "eds" +pipeline = [ + "normalizer", + "sentencizer", + "covid", + "dates", + "relations", + ] +batch_size = 2 +components = ${components} +tokenizer = {"@tokenizers": "eds.tokenizer"} + +[components.normalizer] +@factory = "eds.normalizer" + +[components.sentencizer] +@factory = "eds.sentences" + +[components.covid] +@factory = "eds.covid" + +[components.dates] +@factory = "eds.dates" + +# Relations component is: +# - a span relation detector, that classifies pairs spans embedded by... +# - a span pooler, that pools words embedded by... 
+# - a text cnn, that re-contextualizes words embedded by... +# - a transformer +[components.relations] +@factory = "eds.relation_detector_ffn" +head_getter = {"ents": "covid"} +tail_getter = {"ents": "date"} +labels = ["linked"] +symmetric = true + +[components.relations.word_embedding] +@factory = "eds.text_cnn" +kernel_sizes = [3] + +[components.relations.word_embedding.embedding] +@factory = "eds.transformer" +model = "hf-internal-testing/tiny-bert" +window = 128 +stride = 96 + +[components.relations.span_embedding] +@factory = "eds.span_pooler" +embedding = ${components.relations.word_embedding} + +[scorer.rel] +@metrics = "eds.relations" +head_getter = ${components.relations.head_getter} +tail_getter = ${components.relations.tail_getter} +labels = ${components.relations.labels} + +[train] +nlp = ${nlp} +max_steps = 50 +validation_interval = ${train.max_steps//10} +warmup_rate = 0 +batch_size = 2 samples +transformer_lr = 0 +task_lr = 1e-3 +scorer = ${scorer} + +[train.train_data] +randomize = true +max_length = 100 +multi_sentence = true +[train.train_data.reader] +@readers = "standoff" +path = "./dataset_2/" + +[train.val_data] +[train.val_data.reader] +@readers = "standoff" +path = "./dataset_2/" diff --git a/tests/training/rel_config.yml b/tests/training/rel_config.yml new file mode 100644 index 0000000000..1bdc3fd014 --- /dev/null +++ b/tests/training/rel_config.yml @@ -0,0 +1,90 @@ +# 🤖 PIPELINE DEFINITION +nlp: + "@core": pipeline + + lang: eds + + components: + normalizer: + '@factory': eds.normalizer + + sentencizer: + '@factory': eds.sentences + + covid: + '@factory': eds.covid + + relations: + '@factory': "eds.relation_detector_ffn" + candidate_getter: + head: { "ents": "covid" } + tail: { "ents": "date" } + labels: [ "linked" ] + symmetric: true + + word_embedding: + '@factory': eds.text_cnn + kernel_sizes: [ 3 ] + + embedding: + '@factory': eds.transformer + model: hf-internal-testing/tiny-bert + window: 128 + stride: 96 + + span_embedding: + '@factory': eds.span_pooler + embedding: ${nlp.components.relations.word_embedding} + +# 📈 SCORERS +scorer: + speed: true + batch_size: 2 docs + rel: + "@metrics": eds.relations + candidate_getter: ${nlp.components.relations.candidate_getter} + +# 🎛️ OPTIMIZER +# (disabled to test the default optimizer) +# optimizer: +# "@optimizers": adam +# groups: +# "*.transformer.*": +# lr: 1e-3 +# schedules: +# "@schedules": linear +# "warmup_rate": 0.1 +# "start_value": 0 +# "*": +# lr: 1e-3 +# schedules: +# "@schedules": linear +# "warmup_rate": 0.1 +# "start_value": 1e-3 + +# 📚 DATA +train_data: + - data: + '@readers': standoff + path: ./dataset_2/ + converter: + - '@factory': eds.standoff_dict2doc + shuffle: dataset + batch_size: 1 docs + +val_data: + '@readers': standoff + path: ./dataset_2/ + converter: + - '@factory': eds.standoff_dict2doc + +# 🚀 TRAIN SCRIPT OPTIONS +train: + nlp: ${ nlp } + train_data: ${ train_data } + val_data: ${ val_data } + max_steps: 40 + validation_interval: 10 + max_grad_norm: 1.0 + scorer: ${ scorer } + num_workers: 1 diff --git a/tests/training/test_train.py b/tests/training/test_train.py index c5abe5a855..42db5e9d77 100644 --- a/tests/training/test_train.py +++ b/tests/training/test_train.py @@ -208,6 +208,22 @@ def test_dep_parser_train(run_in_test_dir, tmp_path): assert last_scores["dep"]["las"] >= 0.4 +def test_rel_train(run_in_test_dir, tmp_path): + set_seed(42) + config = Config.from_disk("rel_config.yml") + shutil.rmtree(tmp_path, ignore_errors=True) + kwargs = Config.resolve(config["train"], 
registry=registry, root=config) + nlp = train(**kwargs, output_dir=tmp_path, cpu=True) + scorer = GenericScorer(**kwargs["scorer"]) + val_data = kwargs["val_data"] + last_scores = scorer(nlp, val_data) + + # Check empty doc + nlp("") + + assert last_scores["rel"]["micro"]["f"] >= 0.4 + + def test_optimizer(): net = torch.nn.Linear(10, 10) optim = ScheduledOptimizer(