diff --git a/docs/source/api/bluesearch.k8s.ner.rst b/docs/source/api/bluesearch.k8s.ner.rst
new file mode 100644
index 000000000..eadae999c
--- /dev/null
+++ b/docs/source/api/bluesearch.k8s.ner.rst
@@ -0,0 +1,7 @@
+bluesearch.k8s.ner module
+=========================
+
+.. automodule:: bluesearch.k8s.ner
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api/bluesearch.k8s.rst b/docs/source/api/bluesearch.k8s.rst
index a76de5009..25fbac0e3 100644
--- a/docs/source/api/bluesearch.k8s.rst
+++ b/docs/source/api/bluesearch.k8s.rst
@@ -9,6 +9,7 @@ Submodules
 
    bluesearch.k8s.connect
    bluesearch.k8s.create_indices
+   bluesearch.k8s.ner
 
 Module contents
 ---------------
diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py
index 00db98efe..ffa4f7c54 100644
--- a/src/bluesearch/k8s/create_indices.py
+++ b/src/bluesearch/k8s/create_indices.py
@@ -26,7 +26,7 @@
 
 SETTINGS = {"number_of_shards": 2, "number_of_replicas": 1}
 
-MAPPINGS_ARTICLES = {
+MAPPINGS_ARTICLES: dict[str, Any] = {
     "dynamic": "strict",
     "properties": {
         "article_id": {"type": "keyword"},
@@ -52,6 +52,12 @@
         "section_name": {"type": "keyword"},
         "paragraph_id": {"type": "short"},
         "text": {"type": "text"},
+        "ner_ml_json_v2": {"type": "flattened"},
+        "ner_ml_version": {"type": "keyword"},
+        "ner_ruler_json_v2": {"type": "flattened"},
+        "ner_ruler_version": {"type": "keyword"},
+        "re": {"type": "flattened"},
+        "re_version": {"type": "keyword"},
         "is_bad": {"type": "boolean"},
         "embedding": {
             "type": "dense_vector",
@@ -119,3 +125,23 @@ def remove_index(client: Elasticsearch, index: str | list[str]) -> None:
     except Exception as err:
         print("Elasticsearch add_index ERROR:", err)
+
+
+def update_index_mapping(
+    client: Elasticsearch,
+    index: str,
+    settings: dict[str, Any] | None = None,
+    properties: dict[str, Any] | None = None,
+) -> None:
+    """Update the index with new settings and mapping properties."""
+    if index not in client.indices.get_alias().keys():
+        raise RuntimeError("Index not in ES")
+
+    try:
+        if settings:
+            client.indices.put_settings(index=index, settings=settings)
+        if properties:
+            client.indices.put_mapping(index=index, properties=properties)
+        logger.info(f"Index {index} updated successfully")
+    except Exception as err:
+        print("Elasticsearch update_index_mapping ERROR:", err)
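The six new fields above extend the paragraphs mapping, and `update_index_mapping` makes it possible to backfill them without recreating the index. A minimal sketch of that usage, assuming a reachable cluster and an existing `paragraphs` index (the property definitions mirror the hunk above):

```python
from bluesearch.k8s import connect
from bluesearch.k8s.create_indices import update_index_mapping

client = connect.connect()

# NER / relation-extraction fields added by this patch.
ner_properties = {
    "ner_ml_json_v2": {"type": "flattened"},
    "ner_ml_version": {"type": "keyword"},
    "ner_ruler_json_v2": {"type": "flattened"},
    "ner_ruler_version": {"type": "keyword"},
    "re": {"type": "flattened"},
    "re_version": {"type": "keyword"},
}
update_index_mapping(client, "paragraphs", properties=ner_properties)
```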
+"""Perform Name Entity Recognition (NER) on paragraphs.""" +from __future__ import annotations + +import logging +import os +import time +from datetime import datetime +from multiprocessing import Pool +from typing import Any + +import elasticsearch +import numpy as np +import pandas as pd +import requests +import tqdm +from dotenv import load_dotenv +from elasticsearch.helpers import scan + +from bluesearch.k8s import connect + +load_dotenv() + +logger = logging.getLogger(__name__) + + +def run( + client: elasticsearch.Elasticsearch, + version: str, + index: str = "paragraphs", + ner_method: str = "ml", + force: bool = False, + n_threads: int = 4, + run_async: bool = True, +) -> None: + """Run the NER pipeline on the paragraphs in the database. + + Parameters + ---------- + client + Elasticsearch client. + version + Version of the NER pipeline. + index + Name of the ES index. + ner_method + Method to use to perform NER. + force + If True, force the NER to be performed even in all paragraphs. + n_threads + Number of threads to use. + run_async + If True, run the NER asynchronously. + """ + # get NER method function and url + if ner_method == "ml": + url = os.environ["BENTOML_NER_ML_URL"] + elif ner_method == "ruler": + url = os.environ["BENTOML_NER_RULER_URL"] + else: + raise ValueError("The ner_method should be either 'ml' or 'ruler'.") + + # get paragraphs without NER unless force is True + if force: + query: dict[str, Any] = {"match_all": {}} + else: + query = {"bool": {"must_not": {"term": {f"ner_{ner_method}_version": version}}}} + paragraph_count = client.options(request_timeout=30).count( + index=index, query=query + )["count"] + logger.info( + f"There are {paragraph_count} paragraphs without NER {ner_method} results." + ) + + # performs NER for all the documents + progress = tqdm.tqdm( + total=paragraph_count, + position=0, + unit=" Paragraphs", + desc="Updating NER", + ) + if run_async: + # start a pool of workers + pool = Pool(processes=n_threads) + open_threads = [] + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): + # add a new thread to the pool + res = pool.apply_async( + run_ner_model_remote, + args=( + hit, + url, + ner_method, + index, + version, + ), + ) + open_threads.append(res) + progress.update(1) + # check if any thread is done + open_threads = [thr for thr in open_threads if not thr.ready()] + # wait if too many threads are running + while len(open_threads) > n_threads: + time.sleep(0.1) + open_threads = [thr for thr in open_threads if not thr.ready()] + # wait for all threads to finish + pool.close() + pool.join() + else: + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): + run_ner_model_remote( + hit=hit, + url=url, + ner_method=ner_method, + index=index, + version=version, + client=client, + ) + progress.update(1) + + progress.close() + + +def run_ner_model_remote( + hit: dict[str, Any], + url: str, + ner_method: str, + index: str | None = None, + version: str | None = None, + client: elasticsearch.Elasticsearch | None = None, +) -> list[dict[str, Any]] | None: + """Perform NER on a paragraph using a remote server. + + Parameters + ---------- + hit + Elasticsearch hit. + url + URL of the NER server. + ner_method + Method to use to perform NER. + index + Name of the ES index. + version + Version of the NER pipeline. 
+ """ + if client is None and index is None and version is None: + logger.info("Running NER in inference mode only.") + elif client is None and index is not None and version is not None: + client = connect.connect() + elif client is None and (index is not None or version is not None): + raise ValueError("Index and version should be both None or not None.") + + url = "http://" + url + "/predict" + + response = requests.post( + url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + data=hit["_source"]["text"].encode("utf-8"), + ) + + if not response.status_code == 200: + raise ValueError("Error in the request") + + results = response.json() + out = [] + if results: + for res in results: + row = {} + row["entity_type"] = res["entity_group"] + row["entity"] = res["word"] + row["start"] = res["start"] + row["end"] = res["end"] + row["score"] = 0 if ner_method == "ruler" else res["score"] + row["source"] = ner_method + out.append(row) + else: + # if no entity is found, return an empty row, + # necessary for ES to find the field in the document + row = {} + row["entity_type"] = "Empty" + row["entity"] = "" + row["start"] = 0 + row["end"] = 0 + row["score"] = 0 + row["source"] = ner_method + out.append(row) + + if client is not None and index is not None and version is not None: + # update the NER field in the document + client.update( + index=index, doc={f"ner_{ner_method}_json_v2": out}, id=hit["_id"] + ) + # update the version of the NER + client.update( + index=index, doc={f"ner_{ner_method}_version": version}, id=hit["_id"] + ) + return None + else: + return out + + +def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: + """Handle conflicts between the NER pipeline and the entity ruler. + + Parameters + ---------- + results_paragraph + List of entities found by the NER pipeline. + """ + # if there is only one entity, it will be kept + if len(results_paragraph) <= 1: + return results_paragraph + + temp = sorted( + results_paragraph, + key=lambda x: (-(x["end"] - x["start"]), x["source"]), + ) + + results_cleaned: list[dict] = [] + + array = np.zeros(max([x["end"] for x in temp])) + for res in temp: + add_one = 1 if res["entity"][0] == " " else 0 + sub_one = 1 if res["entity"][-1] == " " else 0 + if len(results_cleaned) == 0: + results_cleaned.append(res) + array[res["start"] + add_one : res["end"] - sub_one] = 1 + else: + if array[res["start"] + add_one : res["end"] - sub_one].sum() == 0: + results_cleaned.append(res) + array[res["start"] + add_one : res["end"] - sub_one] = 1 + + results_cleaned.sort(key=lambda x: x["start"]) + return results_cleaned + + +def retrieve_csv( + client: elasticsearch.Elasticsearch, + index: str = "paragraphs", + ner_method: str = "both", + output_path: str = "./", +) -> None: + """Retrieve the NER results from the database and save them in a csv file. + + Parameters + ---------- + client + Elasticsearch client. + index + Name of the ES index. + ner_method + Method to use to perform NER. + output_path + Path where one wants to save the csv file. 
+ """ + now = datetime.now().strftime("%d_%m_%Y_%H_%M") + + if ner_method == "both": + query: dict[str, dict[str, Any]] = { + "bool": { + "filter": [ + {"exists": {"field": "ner_ml_json_v2"}}, + {"exists": {"field": "ner_ruler_json_v2"}}, + ] + } + } + elif ner_method in ["ml", "ruler"]: + query = {"exists": {"field": f"ner_{ner_method}_json_v2"}} + else: + raise ValueError("The ner_method should be either 'ml', 'ruler' or 'both'.") + + paragraph_count = client.count(index=index, query=query)["count"] + logger.info( + f"There are {paragraph_count} paragraphs with NER {ner_method} results." + ) + + progress = tqdm.tqdm( + total=paragraph_count, + position=0, + unit=" Paragraphs", + desc="Retrieving NER", + ) + results = [] + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): + if ner_method == "both": + results_paragraph = [ + *hit["_source"]["ner_ml_json_v2"], + *hit["_source"]["ner_ruler_json_v2"], + ] + results_paragraph = handle_conflicts(results_paragraph) + else: + results_paragraph = hit["_source"][f"ner_{ner_method}_json_v2"] + + for res in results_paragraph: + row = {} + row["entity_type"] = res["entity_type"] + row["entity"] = res["entity"] + row["start"] = res["start"] + row["end"] = res["end"] + row["source"] = res["source"] + row["paragraph_id"] = hit["_id"] + row["article_id"] = hit["_source"]["article_id"] + results.append(row) + + progress.update(1) + logger.info(f"Retrieved NER for paragraph {hit['_id']}, progress: {progress.n}") + + progress.close() + + df = pd.DataFrame(results) + df.to_csv(f"{output_path}/ner_es_results{ner_method}_{now}.csv", index=False) + + +if __name__ == "__main__": + logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARNING) + client = connect.connect() + run(client, version="v2") diff --git a/tests/conftest.py b/tests/conftest.py index 43d0f97ba..a3cac711a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -518,5 +518,5 @@ def get_es_client(monkeypatch): if client is not None: for index in client.indices.get_alias().keys(): - if index in ["articles", "paragraphs", "test_index"]: + if index in ["test_articles", "test_paragraphs", "test_index"]: remove_index(client, index) diff --git a/tests/unit/k8s/test_create_indices.py b/tests/unit/k8s/test_create_indices.py index 75be3178d..630815212 100644 --- a/tests/unit/k8s/test_create_indices.py +++ b/tests/unit/k8s/test_create_indices.py @@ -1,6 +1,12 @@ import pytest -from bluesearch.k8s.create_indices import add_index, remove_index +from bluesearch.k8s.create_indices import ( + MAPPINGS_ARTICLES, + SETTINGS, + add_index, + remove_index, + update_index_mapping, +) def test_create_and_remove_index(get_es_client): @@ -13,3 +19,43 @@ def test_create_and_remove_index(get_es_client): add_index(client, index) remove_index(client, index) + + +def test_update_index_mapping(get_es_client): + client = get_es_client + + if client is None: + pytest.skip("Elastic search is not available") + + index = "test_index" + + add_index(client, index, settings=SETTINGS, mappings=MAPPINGS_ARTICLES) + + index_settings = client.indices.get_settings(index=index) + assert index_settings[index]["settings"]["index"]["number_of_replicas"] == str( + SETTINGS["number_of_replicas"] + ) + assert ( + client.indices.get_mapping(index=index)[index]["mappings"] == MAPPINGS_ARTICLES + ) + + fake_settings = {"number_of_replicas": 2} + fake_properties = {"x": {"type": "text"}} + update_index_mapping( + client, + index, + settings=fake_settings, + properties=fake_properties, + ) + + 
diff --git a/tests/conftest.py b/tests/conftest.py
index 43d0f97ba..a3cac711a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -518,5 +518,5 @@ def get_es_client(monkeypatch):
 
     if client is not None:
         for index in client.indices.get_alias().keys():
-            if index in ["articles", "paragraphs", "test_index"]:
+            if index in ["test_articles", "test_paragraphs", "test_index"]:
                 remove_index(client, index)
diff --git a/tests/unit/k8s/test_create_indices.py b/tests/unit/k8s/test_create_indices.py
index 75be3178d..630815212 100644
--- a/tests/unit/k8s/test_create_indices.py
+++ b/tests/unit/k8s/test_create_indices.py
@@ -1,6 +1,12 @@
 import pytest
 
-from bluesearch.k8s.create_indices import add_index, remove_index
+from bluesearch.k8s.create_indices import (
+    MAPPINGS_ARTICLES,
+    SETTINGS,
+    add_index,
+    remove_index,
+    update_index_mapping,
+)
 
 
 def test_create_and_remove_index(get_es_client):
@@ -13,3 +19,43 @@
 
     add_index(client, index)
     remove_index(client, index)
+
+
+def test_update_index_mapping(get_es_client):
+    client = get_es_client
+
+    if client is None:
+        pytest.skip("Elasticsearch is not available")
+
+    index = "test_index"
+
+    add_index(client, index, settings=SETTINGS, mappings=MAPPINGS_ARTICLES)
+
+    index_settings = client.indices.get_settings(index=index)
+    assert index_settings[index]["settings"]["index"]["number_of_replicas"] == str(
+        SETTINGS["number_of_replicas"]
+    )
+    assert (
+        client.indices.get_mapping(index=index)[index]["mappings"] == MAPPINGS_ARTICLES
+    )
+
+    fake_settings = {"number_of_replicas": 2}
+    fake_properties = {"x": {"type": "text"}}
+    update_index_mapping(
+        client,
+        index,
+        settings=fake_settings,
+        properties=fake_properties,
+    )
+
+    index_settings = client.indices.get_settings(index=index)
+    assert index_settings[index]["settings"]["index"]["number_of_replicas"] == "2"
+
+    # copy "properties" as well: a plain .copy() is shallow and would
+    # mutate the module-level MAPPINGS_ARTICLES constant
+    NEW_MAPPINGS_ARTICLES = dict(MAPPINGS_ARTICLES)
+    NEW_MAPPINGS_ARTICLES["properties"] = dict(MAPPINGS_ARTICLES["properties"])
+    NEW_MAPPINGS_ARTICLES["properties"]["x"] = {"type": "text"}
+    assert (
+        client.indices.get_mapping(index=index)[index]["mappings"]
+        == NEW_MAPPINGS_ARTICLES
+    )
+
+    remove_index(client, index)
diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py
new file mode 100644
index 000000000..db05c78ac
--- /dev/null
+++ b/tests/unit/k8s/test_ner.py
@@ -0,0 +1,501 @@
+import pytest
+import responses
+
+from bluesearch.k8s.create_indices import MAPPINGS_PARAGRAPHS, add_index, remove_index
+from bluesearch.k8s.ner import handle_conflicts, run, run_ner_model_remote
+
+
+@pytest.fixture()
+def model_response():
+    url = "fake_url"
+    expected_url = "http://" + url + "/predict"
+    expected_response = [
+        {
+            "entity_group": "ORGANISM",
+            "score": 0.9439833760261536,
+            "word": "cat",
+            "start": 11,
+            "end": 14,
+        },
+        {
+            "entity_group": "ORGANISM",
+            "score": 0.9975798726081848,
+            "word": "mouse",
+            "start": 21,
+            "end": 26,
+        },
+    ]
+
+    responses.add(
+        responses.POST,
+        expected_url,
+        headers={"accept": "application/json", "Content-Type": "text/plain"},
+        json=expected_response,
+    )
+
+    responses.add(
+        responses.POST,
+        expected_url,
+        headers={"accept": "application/json", "Content-Type": "text/plain"},
+        json=expected_response,
+    )
+
+
+@responses.activate
+def test_run_ner_model_remote(model_response):
+    url = "fake_url"
+    hit = {"_source": {"text": "There is a cat and a mouse in the house."}}
+
+    out = run_ner_model_remote(hit, url, ner_method="ml")
+    assert isinstance(out, list)
+    assert len(out) == 2
+
+    assert out[0]["source"] == "ml"
+    assert out[0]["score"] == 0.9439833760261536
+    assert out[1]["score"] == 0.9975798726081848
+    assert out[0]["entity_type"] == "ORGANISM"
+    assert out[0]["entity"] == "cat"
+    assert out[0]["start"] == 11
+    assert out[0]["end"] == 14
+
+    out = run_ner_model_remote(hit, url, ner_method="ruler")
+    assert isinstance(out, list)
+    assert out[0]["score"] == 0
+    assert out[1]["score"] == 0
+
+
+@responses.activate
+def test_run(monkeypatch, get_es_client, model_response):
+
+    BENTOML_NER_ML_URL = "fake_url"
+    BENTOML_NER_RULER_URL = "fake_url"
+
+    monkeypatch.setenv("BENTOML_NER_ML_URL", BENTOML_NER_ML_URL)
+    monkeypatch.setenv("BENTOML_NER_RULER_URL", BENTOML_NER_RULER_URL)
+
+    client = get_es_client
+
+    if client is None:
+        pytest.skip("Elasticsearch is not available")
+
+    index = "test_index"
+    add_index(client, index, mappings=MAPPINGS_PARAGRAPHS)
+
+    fake_data = [
+        {
+            "article_id": "1",
+            "paragraph_id": "1",
+            "text": "There is a cat and a mouse in the house.",
+        },
+        {
+            "article_id": "1",
+            "paragraph_id": "2",
+            "text": "There is a cat and a mouse in the house.",
+        },
+        {
+            "article_id": "2",
+            "paragraph_id": "3",
+            "text": "There is a cat and a mouse in the house.",
+        },
+    ]
+
+    for fd in fake_data:
+        client.index(index=index, document=fd, id=fd["paragraph_id"])
+
+    client.indices.refresh(index=index)
+    run(client, "v1", index=index, run_async=False)
+    client.indices.refresh(index=index)
+
+    # check that the results are in the database
+    query = {"bool": {"must": {"term": {"ner_ml_version": "v1"}}}}
+    paragraph_count = client.count(index=index, query=query)["count"]
+    assert paragraph_count == 3
+
+    # check that the results are correct
+    for fd in fake_data:
+        res = client.get(index=index, id=fd["paragraph_id"])
res["_source"]["ner_ml_json_v2"][0]["entity"] == "cat" + assert res["_source"]["ner_ml_json_v2"][0]["score"] == 0.9439833760261536 + assert res["_source"]["ner_ml_json_v2"][0]["entity_type"] == "ORGANISM" + assert res["_source"]["ner_ml_json_v2"][0]["start"] == 11 + assert res["_source"]["ner_ml_json_v2"][0]["end"] == 14 + assert res["_source"]["ner_ml_json_v2"][0]["source"] == "ml" + + assert res["_source"]["ner_ml_json_v2"][1]["entity"] == "mouse" + assert res["_source"]["ner_ml_json_v2"][1]["score"] == 0.9975798726081848 + assert res["_source"]["ner_ml_json_v2"][1]["entity_type"] == "ORGANISM" + assert res["_source"]["ner_ml_json_v2"][1]["start"] == 21 + assert res["_source"]["ner_ml_json_v2"][1]["end"] == 26 + assert res["_source"]["ner_ml_json_v2"][1]["source"] == "ml" + + assert res["_source"]["ner_ml_version"] == "v1" + + # check that all paragraphs have been updated + query = {"bool": {"must_not": {"term": {"ner_ml_version": "v1"}}}} + paragraph_count = client.count(index=index, query=query)["count"] + assert paragraph_count == 0 + + remove_index(client, index) + + +@pytest.mark.parametrize( + ("raw_ents", "cleaned_ents"), + [ + pytest.param([], [], id="empty list"), + pytest.param( + [ + { + "start": 1, + "end": 5, + "entity": "entity", + }, + ], + [ + { + "start": 1, + "end": 5, + "entity": "entity", + }, + ], + id="one element", + ), + pytest.param( + [ + { + "start": 1, + "end": 5, + "entity": "entity", + "source": "RULES", + }, + { + "start": 10, + "end": 15, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 5, + "entity": "entity", + "source": "RULES", + }, + { + "start": 10, + "end": 15, + "entity": "entity", + "source": "ML", + }, + ], + id="no overlap", + ), + pytest.param( + [ + { + "start": 1, + "end": 5, + "entity": "entity", + "source": "RULES", + }, + { + "start": 1, + "end": 5, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 5, + "entity": "entity", + "source": "ML", + }, + ], + id="perfect overlap", + ), + pytest.param( + [ + { + "start": 1, + "end": 5, + "entity": "entity", + "source": "RULES", + }, + { + "start": 2, + "end": 20, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 2, + "end": 20, + "entity": "entity", + "source": "ML", + }, + ], + id="overlap - ML longer", + ), + pytest.param( + [ + { + "start": 1, + "end": 50, + "entity": "entity", + "source": "RULES", + }, + { + "start": 25, + "end": 60, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 50, + "entity": "entity", + "source": "RULES", + }, + ], + id="overlap - RULES longer", + ), + pytest.param( + [ + { + "start": 1, + "end": 50, + "entity": "entity", + "source": "RULES", + }, + { + "start": 25, + "end": 40, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 50, + "entity": "entity", + "source": "RULES", + }, + ], + id="overlap - ML subset of RULES", + ), + pytest.param( + [ + { + "start": 4, + "end": 24, + "entity": "entity", + "source": "RULES", + }, + { + "start": 2, + "end": 40, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 2, + "end": 40, + "entity": "entity", + "source": "ML", + }, + ], + id="overlap - RULES subset of ML", + ), + pytest.param( + [ + { + "start": 10, + "end": 30, + "entity": "entity", + "source": "RULES", + }, + { + "start": 20, + "end": 40, + "entity": "entity", + "source": "ML", + }, + ], + [ + { + "start": 20, + "end": 40, + "entity": "entity", + "source": "ML", + }, + ], + id="overlap - same length", + ), + 
+        pytest.param(
+            [
+                {
+                    "start": 10,
+                    "end": 30,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+                {
+                    "start": 15,
+                    "end": 20,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 23,
+                    "end": 34,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 31,
+                    "end": 33,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+            ],
+            [
+                {
+                    "start": 10,
+                    "end": 30,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+                {
+                    "start": 31,
+                    "end": 33,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+            ],
+            id="more entries - 1",
+        ),
+        pytest.param(
+            [
+                {
+                    "start": 10,
+                    "end": 30,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+                {
+                    "start": 20,
+                    "end": 50,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 35,
+                    "end": 100,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+            ],
+            [
+                {
+                    "start": 10,
+                    "end": 30,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+                {
+                    "start": 35,
+                    "end": 100,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+            ],
+            id="more entries - 2",
+        ),
+        pytest.param(
+            [
+                {
+                    "start": 10,
+                    "end": 12,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+                {
+                    "start": 10,
+                    "end": 12,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 35,
+                    "end": 100,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+            ],
+            [
+                {
+                    "start": 10,
+                    "end": 12,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 35,
+                    "end": 100,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+            ],
+            id="entries with only 2 chars",
+        ),
+        pytest.param(
+            [
+                {
+                    "start": 10,
+                    "end": 12,
+                    "entity": "entity",
+                    "source": "RULES",
+                },
+                {
+                    "start": 10,
+                    "end": 12,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 12,
+                    "end": 15,
+                    "entity": " entity",
+                    "source": "RULES",
+                },
+            ],
+            [
+                {
+                    "start": 10,
+                    "end": 12,
+                    "entity": "entity",
+                    "source": "ML",
+                },
+                {
+                    "start": 12,
+                    "end": 15,
+                    "entity": " entity",
+                    "source": "RULES",
+                },
+            ],
+            id="overlap with whitespace",
+        ),
+    ],
+)
+def test_handle_conflicts(raw_ents, cleaned_ents):
+    """Test handle_conflicts function."""
+    assert cleaned_ents == handle_conflicts(raw_ents)
diff --git a/tox.ini b/tox.ini
index 8264317cc..29ee10dd6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -157,6 +157,7 @@ filterwarnings =
     ignore::elastic_transport.SecurityWarning
     ignore::urllib3.exceptions.InsecureRequestWarning
    ignore::luigi.parameter.UnconsumedParameterWarning:
+    ignore::UserWarning:spacy.language:
 addopts =
     --cov
     --cov-config=tox.ini
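Taken together, a typical end-to-end invocation of the new module would look like the sketch below. It assumes the Elasticsearch cluster and both BentoML NER services are reachable, with `BENTOML_NER_ML_URL` and `BENTOML_NER_RULER_URL` set (for example via the `.env` file picked up by `load_dotenv`):

```python
import logging

from bluesearch.k8s import connect
from bluesearch.k8s.ner import retrieve_csv, run

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
client = connect.connect()

# Annotate paragraphs with both NER methods, then export the merged
# (conflict-resolved) annotations to a timestamped CSV.
run(client, version="v2", ner_method="ml")
run(client, version="v2", ner_method="ruler")
retrieve_csv(client, ner_method="both", output_path=".")
```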