From de25b3cbb26cd39066649ad40128f7b430157d19 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 26 Oct 2022 10:40:01 +0200 Subject: [PATCH 01/32] add ner k8s --- src/bluesearch/k8s/create_indices.py | 22 +++ src/bluesearch/k8s/ner.py | 204 +++++++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 src/bluesearch/k8s/ner.py diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py index 00db98efe..34d44b121 100644 --- a/src/bluesearch/k8s/create_indices.py +++ b/src/bluesearch/k8s/create_indices.py @@ -52,6 +52,8 @@ "section_name": {"type": "keyword"}, "paragraph_id": {"type": "short"}, "text": {"type": "text"}, + "ner_ml": {"type": "flattened"}, + "ner_ruler": {"type": "flattened"}, "is_bad": {"type": "boolean"}, "embedding": { "type": "dense_vector", @@ -119,3 +121,23 @@ def remove_index(client: Elasticsearch, index: str | list[str]) -> None: except Exception as err: print("Elasticsearch add_index ERROR:", err) + + +def update_index_mapping( + client: Elasticsearch, + index: str, + settings: dict[str, Any] | None = None, + mappings: dict[str, Any] | None = None, +) -> None: + """Update the index with a new mapping and settings.""" + if index not in client.indices.get_alias().keys(): + raise RuntimeError("Index not in ES") + + try: + if settings: + client.indices.put_settings(index=index, body=settings) + if mappings: + client.indices.put_mapping(index=index, body=mappings) + logger.info(f"Index {index} updated successfully") + except Exception as err: + print("Elasticsearch add_index ERROR:", err) \ No newline at end of file diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py new file mode 100644 index 000000000..d561e4f23 --- /dev/null +++ b/src/bluesearch/k8s/ner.py @@ -0,0 +1,204 @@ +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Perform Name Entity Recognition (NER) on a paragraph.""" +import logging +import os +from typing import Any + +import elasticsearch +import numpy as np +import requests +import tqdm +from dotenv import load_dotenv +from elasticsearch.helpers import scan + +load_dotenv() + +logger = logging.getLogger(__name__) + + +def run( + client: elasticsearch.Elasticsearch, + index: str = "paragraphs", + ner_method: str = "both", + force: bool = False, +) -> None: + """Runs the NER pipeline on the paragraphs in the database. + + Parameters + ---------- + client + Elasticsearch client. + index + Name of the ES index. + ner_method + Method to use to perform NER. + """ + + # get paragraphs without NER unless force is True + if force: + query: dict[str, Any] = {"match_all": {}} + else: + query = {"bool": {"must_not": {"exists": {"field": "ner"}}}} + paragraph_count = client.count(index=index, query=query)["count"] + logger.info("There are {paragraph_count} paragraphs without embeddings") + + # creates NER for all the documents + progress = tqdm.tqdm( + total=paragraph_count, + position=0, + unit=" Paragraphs", + desc="Updating NER", + ) + for hit in scan(client, query={"query": query}, index=index): + if ner_method == "both": + results_ml = run_ml_ner( + hit["_source"]["text"], + os.environ["BENTOML_NER_ML_URL"] + ) + results_ruller = run_ruler_ner( + hit["_source"]["text"], + os.environ["BENTOML_NER_RULER_URL"] + ) + + client.update(index=index, doc={"ner_ml": results_ml}, id=hit["_id"]) + client.update(index=index, doc={"ner_ruler": results_ruller}, id=hit["_id"]) + + elif ner_method == "ml": + results = run_ml_ner( + hit["_source"]["text"], + os.environ["BENTOML_NER_RULER_URL"] + ) + client.update(index=index, doc={"ner_ml": results}, id=hit["_id"]) + elif ner_method == "ruler": + results = run_ruler_ner( + hit["_source"]["text"], + os.environ["BENTOML_NER_RULER_URL"] + ) + client.update(index=index, doc={"ner_ruler": results}, id=hit["_id"]) + else: + raise ValueError(f"Unknown NER method: {ner_method}") + + progress.update(1) + logger.info(f"Updated NER for paragraph {hit['_id']}, progress: {progress.n}") + + progress.close() + + +def run_ml_ner(text: str, url: str) -> list[dict]: + """Runs the NER pipeline on the paragraphs in the database. + + Parameters + ---------- + text + Text to perform NER on. + article_id + Id of the article. + paragraph_id + Id of the paragraph. + ml_model + Name of the ML model to use. + """ + url = "http://" + url + "/predict" + + response = requests.post( + url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + data=text, + ) + + if not response.status_code == 200: + raise ValueError("Error in the request") + + results = response.json() + + out = [] + for res in results: + row = {} + row["entity"] = res["entity_group"] + row["word"] = res["word"] + row["start"] = res["start"] + row["end"] = res["end"] + row["source"] = "ML" + out.append(row) + + return out + + +def run_ruler_ner( + text: str, url: str +) -> list[dict]: + """Runs the NER pipeline on the paragraphs in the database. + + Parameters + ---------- + text + Text to perform NER on. + article_id + Id of the article. + paragraph_id + Id of the paragraph. + ruler_model + Name of the entity ruler model to use. + """ + url = "http://" + url + "/predict" + + response = requests.post( + url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + data=text, + ) + + if not response.status_code == 200: + raise ValueError("Error in the request") + + results = response.json() + + out = [] + for res in results: + row = {} + row["entity"] = res["entity"] + row["word"] = res["word"] + row["start"] = res["start"] + row["end"] = res["end"] + row["source"] = "RULES" + out.append(row) + + return out + + +def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: + """Handle conflicts between the NER pipeline and the entity ruler.""" + # if there is only one entity, it will be kept + if len(results_paragraph) <= 1: + return results_paragraph + + temp = sorted( + results_paragraph, + key=lambda x: (-(x["end"] - x["start"]), x["source"]), + ) + + results_cleaned: list[dict] = [] + + array = np.zeros(max([x["end"] for x in temp])) + for res in temp: + add_one = 1 if res["word"][0] == " " else 0 + sub_one = 1 if res["word"][-1] == " " else 0 + if len(results_cleaned) == 0: + results_cleaned.append(res) + array[res["start"] + add_one : res["end"] - sub_one] = 1 + else: + if array[res["start"] + add_one : res["end"] - sub_one].sum() == 0: + results_cleaned.append(res) + array[res["start"] + add_one : res["end"] - sub_one] = 1 + + results_cleaned.sort(key=lambda x: x["start"]) + return results_cleaned From 10c5bb4cb349596ad62582b944c0a8c75b3182c3 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 26 Oct 2022 14:23:04 +0200 Subject: [PATCH 02/32] updated based on Emilie PR --- src/bluesearch/k8s/create_indices.py | 11 ++- src/bluesearch/k8s/ner.py | 133 ++++++++++----------------- 2 files changed, 56 insertions(+), 88 deletions(-) diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py index 34d44b121..28ff0a5f2 100644 --- a/src/bluesearch/k8s/create_indices.py +++ b/src/bluesearch/k8s/create_indices.py @@ -54,6 +54,7 @@ "text": {"type": "text"}, "ner_ml": {"type": "flattened"}, "ner_ruler": {"type": "flattened"}, + "re": {"type": "flattened"}, "is_bad": {"type": "boolean"}, "embedding": { "type": "dense_vector", @@ -127,7 +128,7 @@ def update_index_mapping( client: Elasticsearch, index: str, settings: dict[str, Any] | None = None, - mappings: dict[str, Any] | None = None, + properties: dict[str, Any] | None = None, ) -> None: """Update the index with a new mapping and settings.""" if index not in client.indices.get_alias().keys(): @@ -135,9 +136,9 @@ def update_index_mapping( try: if settings: - client.indices.put_settings(index=index, body=settings) - if mappings: - client.indices.put_mapping(index=index, body=mappings) + client.indices.put_settings(index=index, settings=settings) + if properties: + client.indices.put_mapping(index=index, properties=properties) logger.info(f"Index {index} updated successfully") except Exception as err: - print("Elasticsearch add_index ERROR:", err) \ No newline at end of file + print("Elasticsearch add_index ERROR:", err) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index d561e4f23..99bde36d4 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -1,3 +1,9 @@ +# Blue Brain Search is a text mining toolbox focused on scientific use cases. +# +# Copyright (C) 2020 Blue Brain Project, EPFL. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # @@ -8,7 +14,7 @@ # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -"""Perform Name Entity Recognition (NER) on a paragraph.""" +"""Perform Name Entity Recognition (NER) on paragraphs.""" import logging import os from typing import Any @@ -28,10 +34,10 @@ def run( client: elasticsearch.Elasticsearch, index: str = "paragraphs", - ner_method: str = "both", + ner_method: str = "ml", force: bool = False, ) -> None: - """Runs the NER pipeline on the paragraphs in the database. + """Run the NER pipeline on the paragraphs in the database. Parameters ---------- @@ -41,17 +47,30 @@ def run( Name of the ES index. ner_method Method to use to perform NER. + force + If True, force the NER to be performed even in all paragraphs. """ + # get NER method function and url + if ner_method == "ml": + url = os.environ["BENTOML_NER_ML_URL"] + elif ner_method == "ruler": + url = os.environ["BENTOML_NER_RULER_URL"] + else: + raise ValueError("The ner_method should be either 'ml' or 'ruler'.") # get paragraphs without NER unless force is True if force: query: dict[str, Any] = {"match_all": {}} else: - query = {"bool": {"must_not": {"exists": {"field": "ner"}}}} + query = { + "bool": {"must_not": {"exists": {"field": f"ner_{ner_method}_json_v2"}}} + } paragraph_count = client.count(index=index, query=query)["count"] - logger.info("There are {paragraph_count} paragraphs without embeddings") + logger.info( + f"There are {paragraph_count} paragraphs without NER {ner_method} results." + ) - # creates NER for all the documents + # performs NER for all the documents progress = tqdm.tqdm( total=paragraph_count, position=0, @@ -59,60 +78,50 @@ def run( desc="Updating NER", ) for hit in scan(client, query={"query": query}, index=index): - if ner_method == "both": - results_ml = run_ml_ner( + try: + results = run_ner_remote( hit["_source"]["text"], - os.environ["BENTOML_NER_ML_URL"] + url, + ner_method, ) - results_ruller = run_ruler_ner( - hit["_source"]["text"], - os.environ["BENTOML_NER_RULER_URL"] + client.update( + index=index, doc={f"ner_{ner_method}_json_v2": results}, id=hit["_id"] ) - client.update(index=index, doc={"ner_ml": results_ml}, id=hit["_id"]) - client.update(index=index, doc={"ner_ruler": results_ruller}, id=hit["_id"]) - - elif ner_method == "ml": - results = run_ml_ner( - hit["_source"]["text"], - os.environ["BENTOML_NER_RULER_URL"] + progress.update(1) + logger.info( + f"Updated NER for paragraph {hit['_id']}, progress: {progress.n}" ) - client.update(index=index, doc={"ner_ml": results}, id=hit["_id"]) - elif ner_method == "ruler": - results = run_ruler_ner( - hit["_source"]["text"], - os.environ["BENTOML_NER_RULER_URL"] - ) - client.update(index=index, doc={"ner_ruler": results}, id=hit["_id"]) - else: - raise ValueError(f"Unknown NER method: {ner_method}") - - progress.update(1) - logger.info(f"Updated NER for paragraph {hit['_id']}, progress: {progress.n}") + except Exception as e: + print(e) + logger.error(f"Error in paragraph {hit['_id']}, progress: {progress.n}") progress.close() -def run_ml_ner(text: str, url: str) -> list[dict]: - """Runs the NER pipeline on the paragraphs in the database. +def run_ner_remote(text: str, url: str, source: str) -> list[dict]: + """Run NER on the remote server for a specific paragraph text. Parameters ---------- text Text to perform NER on. - article_id - Id of the article. - paragraph_id - Id of the paragraph. - ml_model - Name of the ML model to use. + url + URL of the remote server. + source + Source model of the NER results. + + Returns + ------- + results + List of dictionaries with the NER results. """ url = "http://" + url + "/predict" response = requests.post( url, headers={"accept": "application/json", "Content-Type": "text/plain"}, - data=text, + data=text.encode("utf-8"), ) if not response.status_code == 200: @@ -127,49 +136,7 @@ def run_ml_ner(text: str, url: str) -> list[dict]: row["word"] = res["word"] row["start"] = res["start"] row["end"] = res["end"] - row["source"] = "ML" - out.append(row) - - return out - - -def run_ruler_ner( - text: str, url: str -) -> list[dict]: - """Runs the NER pipeline on the paragraphs in the database. - - Parameters - ---------- - text - Text to perform NER on. - article_id - Id of the article. - paragraph_id - Id of the paragraph. - ruler_model - Name of the entity ruler model to use. - """ - url = "http://" + url + "/predict" - - response = requests.post( - url, - headers={"accept": "application/json", "Content-Type": "text/plain"}, - data=text, - ) - - if not response.status_code == 200: - raise ValueError("Error in the request") - - results = response.json() - - out = [] - for res in results: - row = {} - row["entity"] = res["entity"] - row["word"] = res["word"] - row["start"] = res["start"] - row["end"] = res["end"] - row["source"] = "RULES" + row["source"] = source out.append(row) return out From d9acafc85e85112a1481c3c91b00d51139600103 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 26 Oct 2022 16:28:33 +0200 Subject: [PATCH 03/32] add retrieve csv and initial test --- src/bluesearch/k8s/ner.py | 90 +++++++++++++++++++++++++++++++++++--- tests/unit/k8s/test_ner.py | 15 +++++++ 2 files changed, 100 insertions(+), 5 deletions(-) create mode 100644 tests/unit/k8s/test_ner.py diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 99bde36d4..c928e86b1 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -25,6 +25,8 @@ import tqdm from dotenv import load_dotenv from elasticsearch.helpers import scan +import pandas as pd +from datetime import datetime load_dotenv() @@ -77,9 +79,9 @@ def run( unit=" Paragraphs", desc="Updating NER", ) - for hit in scan(client, query={"query": query}, index=index): + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): try: - results = run_ner_remote( + results = run_ner_model_remote( hit["_source"]["text"], url, ner_method, @@ -99,7 +101,7 @@ def run( progress.close() -def run_ner_remote(text: str, url: str, source: str) -> list[dict]: +def run_ner_model_remote(text: str, url: str, source: str) -> list[dict]: """Run NER on the remote server for a specific paragraph text. Parameters @@ -132,10 +134,11 @@ def run_ner_remote(text: str, url: str, source: str) -> list[dict]: out = [] for res in results: row = {} - row["entity"] = res["entity_group"] - row["word"] = res["word"] + row["entity_type"] = res["entity_group"] + row["entity"] = res["word"] row["start"] = res["start"] row["end"] = res["end"] + row["score"] = 0 if source == "ruler" else res["score"] row["source"] = source out.append(row) @@ -169,3 +172,80 @@ def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: results_cleaned.sort(key=lambda x: x["start"]) return results_cleaned + + +def retrieve_csv( + client: elasticsearch.Elasticsearch, + index: str = "paragraphs", + ner_method: str = "both", + output_path: str = "./", +) -> None: + """Retrieve the NER results from the database and save them in a csv file. + + Parameters + ---------- + client + Elasticsearch client. + index + Name of the ES index. + ner_method + Method to use to perform NER. + """ + now = datetime.now().strftime('%d_%m_%Y_%H_%M') + + if ner_method == "both": + query = { + "bool": { + "filter": [ + {"exists": {"field": "ner_ml_json_v2"}}, + {"exists": {"field": "ner_ruler_json_v2"}}, + ] + } + } + elif ner_method in ["ml", "ruler"]: + query = {"exists": {"field": f"ner_{ner_method}_json_v2"}} + else: + raise ValueError("The ner_method should be either 'ml', 'ruler' or 'both'.") + + paragraph_count = client.count(index=index, query=query)["count"] + logger.info( + f"There are {paragraph_count} paragraphs with NER {ner_method} results." + ) + + progress = tqdm.tqdm( + total=paragraph_count, + position=0, + unit=" Paragraphs", + desc="Retrieving NER", + ) + results = [] + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): + if ner_method == "both": + results_paragraph = [ + *hit["_source"]["ner_ml_json_v2"], + *hit["_source"]["ner_ruler_json_v2"], + ] + results_paragraph = handle_conflicts(results_paragraph) + else: + results_paragraph = hit["_source"][f"ner_{ner_method}_json_v2"] + + for res in results_paragraph: + row = {} + row["entity_type"] = res["entity_type"] + row["entity"] = res["entity"] + row["start"] = res["start"] + row["end"] = res["end"] + row["source"] = res["source"] + row["paragraph_id"] = hit["_id"] + row["article_id"] = hit["_source"]["article_id"] + results.append(row) + + progress.update(1) + logger.info( + f"Retrieved NER for paragraph {hit['_id']}, progress: {progress.n}" + ) + + progress.close() + + df = pd.DataFrame(results) + df.to_csv(f"{output_path}/ner_es_results{ner_method}_{now}.csv", index=False) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py new file mode 100644 index 000000000..4764fdd72 --- /dev/null +++ b/tests/unit/k8s/test_ner.py @@ -0,0 +1,15 @@ +import pytest + +from bluesearch.k8s.create_indices import add_index +from bluesearch.k8s.ner import run, run_ner_model_remote, retrieve_csv + +def test_create_and_remove_index(get_es_client): + client = get_es_client + + if client is None: + pytest.skip("Elastic search is not available") + + index = "test_index" + + add_index(client, index) + From 55816fbc7180492a6511e32cb84775acc46a020b Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Wed, 26 Oct 2022 16:57:45 +0200 Subject: [PATCH 04/32] Create test for create_indices module --- tests/unit/k8s/test_create_indices.py | 41 ++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/unit/k8s/test_create_indices.py b/tests/unit/k8s/test_create_indices.py index 75be3178d..e4b359a02 100644 --- a/tests/unit/k8s/test_create_indices.py +++ b/tests/unit/k8s/test_create_indices.py @@ -1,6 +1,12 @@ import pytest -from bluesearch.k8s.create_indices import add_index, remove_index +from bluesearch.k8s.create_indices import ( + SETTINGS, + MAPPINGS_ARTICLES, + add_index, + remove_index, + update_index_mapping, +) def test_create_and_remove_index(get_es_client): @@ -13,3 +19,36 @@ def test_create_and_remove_index(get_es_client): add_index(client, index) remove_index(client, index) + + +def test_update_index_mapping(get_es_client): + client = get_es_client + + if client is None: + pytest.skip("Elastic search is not available") + + index = "test_index" + + add_index(client, index, settings=SETTINGS, mappings=MAPPINGS_ARTICLES) + + index_settings = client.indices.get_settings(index=index) + assert index_settings[index]["settings"]["index"]["number_of_replicas"] == str(SETTINGS["number_of_replicas"]) + assert client.indices.get_mapping(index=index)[index]["mappings"] == MAPPINGS_ARTICLES + + fake_settings = {"number_of_replicas": 2} + fake_properties = {"x": {"type": "text"}} + update_index_mapping( + client, + index, + settings=fake_settings, + properties=fake_properties, + ) + + index_settings = client.indices.get_settings(index=index) + assert index_settings[index]["settings"]["index"]["number_of_replicas"] == "2" + + NEW_MAPPINGS_ARTICLES = MAPPINGS_ARTICLES.copy() + NEW_MAPPINGS_ARTICLES["properties"]["x"] = {"type": "text"} + assert client.indices.get_mapping(index=index)[index]["mappings"] == NEW_MAPPINGS_ARTICLES + + remove_index(client, index) From e1017aa2f97fc2f8c018faf0cadb010ad54598e3 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Wed, 26 Oct 2022 17:19:10 +0200 Subject: [PATCH 05/32] Add test for run_ner_model_remote --- src/bluesearch/k8s/ner.py | 2 ++ tests/unit/k8s/test_ner.py | 49 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index c928e86b1..56e53f9ad 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -15,6 +15,8 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . """Perform Name Entity Recognition (NER) on paragraphs.""" +from __future__ import annotations + import logging import os from typing import Any diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 4764fdd72..b3cbb1611 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -1,15 +1,60 @@ import pytest +import responses from bluesearch.k8s.create_indices import add_index from bluesearch.k8s.ner import run, run_ner_model_remote, retrieve_csv -def test_create_and_remove_index(get_es_client): + +@responses.activate +def test_run_ner_model_remote(get_es_client): client = get_es_client if client is None: pytest.skip("Elastic search is not available") index = "test_index" - add_index(client, index) + url = "fake_url" + expected_url = "http://" + url + "/predict" + text = "There is a cat and a mouse in the house." + expected_response = [ + { + 'entity_group': 'ORGANISM', + 'score': 0.9439833760261536, + 'word': 'cat', + 'start': 11, + 'end': 14 + }, + { + 'entity_group': 'ORGANISM', + 'score': 0.9975798726081848, + 'word': 'mouse', + 'start': 21, + 'end': 26 + } + ] + + responses.add( + responses.POST, + expected_url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + json=expected_response, + ) + + out = run_ner_model_remote(text, url, source="ml") + assert isinstance(out, list) + assert len(out) == 2 + + assert out[0]["source"] == "ml" + assert out[0]["score"] == 0.9439833760261536 + assert out[1]["score"] == 0.9975798726081848 + assert out[0]["entity_type"] == "ORGANISM" + assert out[0]["entity"] == "cat" + assert out[0]["start"] == 11 + assert out[0]["end"] == 14 + + out = run_ner_model_remote(text, url, source="ruler") + assert out[0]["score"] == 0 + assert out[1]["score"] == 0 + From 93620683594ebab810531e3ad7c8285e9395e943 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Thu, 27 Oct 2022 16:07:31 +0200 Subject: [PATCH 06/32] Fix linting --- src/bluesearch/k8s/ner.py | 12 ++++++------ tests/unit/k8s/test_create_indices.py | 15 +++++++++++---- tests/unit/k8s/test_ner.py | 25 ++++++++++++------------- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 56e53f9ad..c4305cf11 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -19,16 +19,16 @@ import logging import os +from datetime import datetime from typing import Any import elasticsearch import numpy as np +import pandas as pd import requests import tqdm from dotenv import load_dotenv from elasticsearch.helpers import scan -import pandas as pd -from datetime import datetime load_dotenv() @@ -192,8 +192,10 @@ def retrieve_csv( Name of the ES index. ner_method Method to use to perform NER. + output_path + Path where one wants to save the csv file. """ - now = datetime.now().strftime('%d_%m_%Y_%H_%M') + now = datetime.now().strftime("%d_%m_%Y_%H_%M") if ner_method == "both": query = { @@ -243,9 +245,7 @@ def retrieve_csv( results.append(row) progress.update(1) - logger.info( - f"Retrieved NER for paragraph {hit['_id']}, progress: {progress.n}" - ) + logger.info(f"Retrieved NER for paragraph {hit['_id']}, progress: {progress.n}") progress.close() diff --git a/tests/unit/k8s/test_create_indices.py b/tests/unit/k8s/test_create_indices.py index e4b359a02..630815212 100644 --- a/tests/unit/k8s/test_create_indices.py +++ b/tests/unit/k8s/test_create_indices.py @@ -1,8 +1,8 @@ import pytest from bluesearch.k8s.create_indices import ( - SETTINGS, MAPPINGS_ARTICLES, + SETTINGS, add_index, remove_index, update_index_mapping, @@ -32,8 +32,12 @@ def test_update_index_mapping(get_es_client): add_index(client, index, settings=SETTINGS, mappings=MAPPINGS_ARTICLES) index_settings = client.indices.get_settings(index=index) - assert index_settings[index]["settings"]["index"]["number_of_replicas"] == str(SETTINGS["number_of_replicas"]) - assert client.indices.get_mapping(index=index)[index]["mappings"] == MAPPINGS_ARTICLES + assert index_settings[index]["settings"]["index"]["number_of_replicas"] == str( + SETTINGS["number_of_replicas"] + ) + assert ( + client.indices.get_mapping(index=index)[index]["mappings"] == MAPPINGS_ARTICLES + ) fake_settings = {"number_of_replicas": 2} fake_properties = {"x": {"type": "text"}} @@ -49,6 +53,9 @@ def test_update_index_mapping(get_es_client): NEW_MAPPINGS_ARTICLES = MAPPINGS_ARTICLES.copy() NEW_MAPPINGS_ARTICLES["properties"]["x"] = {"type": "text"} - assert client.indices.get_mapping(index=index)[index]["mappings"] == NEW_MAPPINGS_ARTICLES + assert ( + client.indices.get_mapping(index=index)[index]["mappings"] + == NEW_MAPPINGS_ARTICLES + ) remove_index(client, index) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index b3cbb1611..0c0f20cd1 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -2,7 +2,7 @@ import responses from bluesearch.k8s.create_indices import add_index -from bluesearch.k8s.ner import run, run_ner_model_remote, retrieve_csv +from bluesearch.k8s.ner import run_ner_model_remote @responses.activate @@ -20,19 +20,19 @@ def test_run_ner_model_remote(get_es_client): text = "There is a cat and a mouse in the house." expected_response = [ { - 'entity_group': 'ORGANISM', - 'score': 0.9439833760261536, - 'word': 'cat', - 'start': 11, - 'end': 14 + "entity_group": "ORGANISM", + "score": 0.9439833760261536, + "word": "cat", + "start": 11, + "end": 14, }, { - 'entity_group': 'ORGANISM', - 'score': 0.9975798726081848, - 'word': 'mouse', - 'start': 21, - 'end': 26 - } + "entity_group": "ORGANISM", + "score": 0.9975798726081848, + "word": "mouse", + "start": 21, + "end": 26, + }, ] responses.add( @@ -57,4 +57,3 @@ def test_run_ner_model_remote(get_es_client): out = run_ner_model_remote(text, url, source="ruler") assert out[0]["score"] == 0 assert out[1]["score"] == 0 - From 202f14c29a41e27b8c9b067b8ac05f1b894d7861 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Thu, 27 Oct 2022 16:11:17 +0200 Subject: [PATCH 07/32] Update docs --- docs/source/api/bluesearch.k8s.ner.rst | 7 +++++++ docs/source/api/bluesearch.k8s.rst | 1 + 2 files changed, 8 insertions(+) create mode 100644 docs/source/api/bluesearch.k8s.ner.rst diff --git a/docs/source/api/bluesearch.k8s.ner.rst b/docs/source/api/bluesearch.k8s.ner.rst new file mode 100644 index 000000000..eadae999c --- /dev/null +++ b/docs/source/api/bluesearch.k8s.ner.rst @@ -0,0 +1,7 @@ +bluesearch.k8s.ner module +========================= + +.. automodule:: bluesearch.k8s.ner + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bluesearch.k8s.rst b/docs/source/api/bluesearch.k8s.rst index a76de5009..25fbac0e3 100644 --- a/docs/source/api/bluesearch.k8s.rst +++ b/docs/source/api/bluesearch.k8s.rst @@ -9,6 +9,7 @@ Submodules bluesearch.k8s.connect bluesearch.k8s.create_indices + bluesearch.k8s.ner Module contents --------------- From 05cdf4cb02f5a58d2693678017c55baed47b599f Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Thu, 27 Oct 2022 16:17:56 +0200 Subject: [PATCH 08/32] Fix type --- src/bluesearch/k8s/create_indices.py | 2 +- src/bluesearch/k8s/ner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py index 28ff0a5f2..5b0931fe0 100644 --- a/src/bluesearch/k8s/create_indices.py +++ b/src/bluesearch/k8s/create_indices.py @@ -26,7 +26,7 @@ SETTINGS = {"number_of_shards": 2, "number_of_replicas": 1} -MAPPINGS_ARTICLES = { +MAPPINGS_ARTICLES: dict[str, Any] = { "dynamic": "strict", "properties": { "article_id": {"type": "keyword"}, diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index c4305cf11..40608b334 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -198,7 +198,7 @@ def retrieve_csv( now = datetime.now().strftime("%d_%m_%Y_%H_%M") if ner_method == "both": - query = { + query: dict[str, dict[str, Any]] = { "bool": { "filter": [ {"exists": {"field": "ner_ml_json_v2"}}, From 348c11b16355c2abd37d8ab6e2332c7de79c1021 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Thu, 27 Oct 2022 16:46:27 +0200 Subject: [PATCH 09/32] Ignore UserWarning from spacy --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 8264317cc..29ee10dd6 100644 --- a/tox.ini +++ b/tox.ini @@ -157,6 +157,7 @@ filterwarnings = ignore::elastic_transport.SecurityWarning ignore::urllib3.exceptions.InsecureRequestWarning ignore::luigi.parameter.UnconsumedParameterWarning: + ignore::UserWarning:spacy.language: addopts = --cov --cov-config=tox.ini From 7d433fef02bfec7fbea2307b45b5a087a17b3f31 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 14:15:48 +0100 Subject: [PATCH 10/32] add async option --- src/bluesearch/k8s/ner.py | 139 ++++++++++++++++++++++++-------------- 1 file changed, 89 insertions(+), 50 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index c928e86b1..49a00be4b 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -17,16 +17,20 @@ """Perform Name Entity Recognition (NER) on paragraphs.""" import logging import os +import time +from datetime import datetime +from multiprocessing import Pool from typing import Any import elasticsearch import numpy as np +import pandas as pd import requests import tqdm from dotenv import load_dotenv from elasticsearch.helpers import scan -import pandas as pd -from datetime import datetime + +from bluesearch.k8s import connect load_dotenv() @@ -35,9 +39,12 @@ def run( client: elasticsearch.Elasticsearch, + version: str, index: str = "paragraphs", ner_method: str = "ml", force: bool = False, + n_threads: int = 4, + run_async: bool = True, ) -> None: """Run the NER pipeline on the paragraphs in the database. @@ -64,10 +71,10 @@ def run( if force: query: dict[str, Any] = {"match_all": {}} else: - query = { - "bool": {"must_not": {"exists": {"field": f"ner_{ner_method}_json_v2"}}} - } - paragraph_count = client.count(index=index, query=query)["count"] + query = {"bool": {"must_not": {"term": {f"ner_{ner_method}_version": version}}}} + paragraph_count = client.options(request_timeout=30).count( + index=index, query=query + )["count"] logger.info( f"There are {paragraph_count} paragraphs without NER {ner_method} results." ) @@ -79,70 +86,98 @@ def run( unit=" Paragraphs", desc="Updating NER", ) - for hit in scan(client, query={"query": query}, index=index, scroll="12h"): - try: - results = run_ner_model_remote( - hit["_source"]["text"], + if run_async: + # start a pool of workers + pool = Pool(processes=n_threads) + open_threads = [] + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): + # add a new thread to the pool + res = pool.apply_async( + reu_ner_model_remote, + args=( + hit, + url, + ner_method, + index, + version, + ), + ) + open_threads.append(res) + progress.update(1) + # check if any thread is done + open_threads = [thr for thr in open_threads if not thr.ready()] + # wait if too many threads are running + while len(open_threads) > n_threads: + time.sleep(0.1) + open_threads = [thr for thr in open_threads if not thr.ready()] + # wait for all threads to finish + pool.close() + pool.join() + else: + for hit in scan(client, query={"query": query}, index=index, scroll="12h"): + reu_ner_model_remote( + hit, url, ner_method, + index, + version, ) - client.update( - index=index, doc={f"ner_{ner_method}_json_v2": results}, id=hit["_id"] - ) - progress.update(1) - logger.info( - f"Updated NER for paragraph {hit['_id']}, progress: {progress.n}" - ) - except Exception as e: - print(e) - logger.error(f"Error in paragraph {hit['_id']}, progress: {progress.n}") progress.close() -def run_ner_model_remote(text: str, url: str, source: str) -> list[dict]: - """Run NER on the remote server for a specific paragraph text. +def reu_ner_model_remote( + hit: dict[str, Any], + url: str, + ner_method: str, + index: str, + version: str, +) -> None: + """Perform NER on a paragraph using a remote server.""" + client = connect.connect() - Parameters - ---------- - text - Text to perform NER on. - url - URL of the remote server. - source - Source model of the NER results. - - Returns - ------- - results - List of dictionaries with the NER results. - """ url = "http://" + url + "/predict" response = requests.post( url, headers={"accept": "application/json", "Content-Type": "text/plain"}, - data=text.encode("utf-8"), + data=hit["_source"]["text"].encode("utf-8"), ) if not response.status_code == 200: raise ValueError("Error in the request") results = response.json() - out = [] - for res in results: + if results: + for res in results: + row = {} + row["entity_type"] = res["entity_group"] + row["entity"] = res["word"] + row["start"] = res["start"] + row["end"] = res["end"] + row["score"] = 0 if ner_method == "ruler" else res["score"] + row["source"] = ner_method + out.append(row) + else: + # if no entity is found, return an empty row, + # necessary for ES to find the field in the document row = {} - row["entity_type"] = res["entity_group"] - row["entity"] = res["word"] - row["start"] = res["start"] - row["end"] = res["end"] - row["score"] = 0 if source == "ruler" else res["score"] - row["source"] = source + row["entity_type"] = "Empty" + row["entity"] = "" + row["start"] = 0 + row["end"] = 0 + row["score"] = 0 + row["source"] = ner_method out.append(row) - return out + # update the NER field in the document + client.update(index=index, doc={f"ner_{ner_method}_json_v2": out}, id=hit["_id"]) + # update the version of the NER + client.update( + index=index, doc={f"ner_{ner_method}_version": version}, id=hit["_id"] + ) def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: @@ -191,7 +226,7 @@ def retrieve_csv( ner_method Method to use to perform NER. """ - now = datetime.now().strftime('%d_%m_%Y_%H_%M') + now = datetime.now().strftime("%d_%m_%Y_%H_%M") if ner_method == "both": query = { @@ -241,11 +276,15 @@ def retrieve_csv( results.append(row) progress.update(1) - logger.info( - f"Retrieved NER for paragraph {hit['_id']}, progress: {progress.n}" - ) + logger.info(f"Retrieved NER for paragraph {hit['_id']}, progress: {progress.n}") progress.close() df = pd.DataFrame(results) df.to_csv(f"{output_path}/ner_es_results{ner_method}_{now}.csv", index=False) + + +if __name__ == "__main__": + logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARNING) + client = connect.connect() + run(client, version="v2") From de0ac24201d8a727e78be2fa533d7369d3032840 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 14:23:04 +0100 Subject: [PATCH 11/32] small function name fix --- src/bluesearch/k8s/ner.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 9ea3cd548..a6d660e05 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -54,12 +54,18 @@ def run( ---------- client Elasticsearch client. + version + Version of the NER pipeline. index Name of the ES index. ner_method Method to use to perform NER. force If True, force the NER to be performed even in all paragraphs. + n_threads + Number of threads to use. + run_async + If True, run the NER asynchronously. """ # get NER method function and url if ner_method == "ml": @@ -95,7 +101,7 @@ def run( for hit in scan(client, query={"query": query}, index=index, scroll="12h"): # add a new thread to the pool res = pool.apply_async( - reu_ner_model_remote, + run_ner_model_remote, args=( hit, url, @@ -117,7 +123,7 @@ def run( pool.join() else: for hit in scan(client, query={"query": query}, index=index, scroll="12h"): - reu_ner_model_remote( + run_ner_model_remote( hit, url, ner_method, @@ -129,14 +135,28 @@ def run( progress.close() -def reu_ner_model_remote( +def run_ner_model_remote( hit: dict[str, Any], url: str, ner_method: str, index: str, version: str, ) -> None: - """Perform NER on a paragraph using a remote server.""" + """Perform NER on a paragraph using a remote server. + + Parameters + ---------- + hit + Elasticsearch hit. + url + URL of the NER server. + ner_method + Method to use to perform NER. + index + Name of the ES index. + version + Version of the NER pipeline. + """ client = connect.connect() url = "http://" + url + "/predict" @@ -183,7 +203,13 @@ def reu_ner_model_remote( def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: - """Handle conflicts between the NER pipeline and the entity ruler.""" + """Handle conflicts between the NER pipeline and the entity ruler. + + Parameters + ---------- + results_paragraph + List of entities found by the NER pipeline. + """ # if there is only one entity, it will be kept if len(results_paragraph) <= 1: return results_paragraph From 93b4016091445524649f04551ce9fee216c2a04b Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 14:47:49 +0100 Subject: [PATCH 12/32] fix test ner --- src/bluesearch/k8s/ner.py | 36 ++++++++++++++++++++++++------------ tests/unit/k8s/test_ner.py | 6 +++--- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index a6d660e05..47452b58a 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -139,11 +139,12 @@ def run_ner_model_remote( hit: dict[str, Any], url: str, ner_method: str, - index: str, - version: str, -) -> None: + index: str | None = None, + version: str | None = None, + client: elasticsearch.Elasticsearch | None = None, +) -> list[dict[str, Any]] | None: """Perform NER on a paragraph using a remote server. - + Parameters ---------- hit @@ -157,7 +158,12 @@ def run_ner_model_remote( version Version of the NER pipeline. """ - client = connect.connect() + if client is None and index is None and version is None: + logger.info("Running NER in inference mode only.") + elif client is None and index is not None and version is not None: + client = connect.connect() + elif client is None and (index is not None or version is not None): + raise ValueError("Index and version should be both None or not None.") url = "http://" + url + "/predict" @@ -194,17 +200,23 @@ def run_ner_model_remote( row["source"] = ner_method out.append(row) - # update the NER field in the document - client.update(index=index, doc={f"ner_{ner_method}_json_v2": out}, id=hit["_id"]) - # update the version of the NER - client.update( - index=index, doc={f"ner_{ner_method}_version": version}, id=hit["_id"] - ) + if client is not None and index is not None and version is not None: + # update the NER field in the document + client.update( + index=index, doc={f"ner_{ner_method}_json_v2": out}, id=hit["_id"] + ) + # update the version of the NER + client.update( + index=index, doc={f"ner_{ner_method}_version": version}, id=hit["_id"] + ) + return None + else: + return out def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: """Handle conflicts between the NER pipeline and the entity ruler. - + Parameters ---------- results_paragraph diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 0c0f20cd1..7e8734e21 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -17,7 +17,7 @@ def test_run_ner_model_remote(get_es_client): url = "fake_url" expected_url = "http://" + url + "/predict" - text = "There is a cat and a mouse in the house." + hit = {"_source": {"text": "There is a cat and a mouse in the house."}} expected_response = [ { "entity_group": "ORGANISM", @@ -42,7 +42,7 @@ def test_run_ner_model_remote(get_es_client): json=expected_response, ) - out = run_ner_model_remote(text, url, source="ml") + out = run_ner_model_remote(hit, url, ner_method="ml") assert isinstance(out, list) assert len(out) == 2 @@ -54,6 +54,6 @@ def test_run_ner_model_remote(get_es_client): assert out[0]["start"] == 11 assert out[0]["end"] == 14 - out = run_ner_model_remote(text, url, source="ruler") + out = run_ner_model_remote(hit, url, ner_method="ruler") assert out[0]["score"] == 0 assert out[1]["score"] == 0 From 9b81f2f0d3510d8244c2f2661d40f80b3e8cb352 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 14:50:31 +0100 Subject: [PATCH 13/32] fix test ner --- tests/unit/k8s/test_ner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 7e8734e21..3bc4629a7 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -55,5 +55,6 @@ def test_run_ner_model_remote(get_es_client): assert out[0]["end"] == 14 out = run_ner_model_remote(hit, url, ner_method="ruler") + assert isinstance(out, list) assert out[0]["score"] == 0 assert out[1]["score"] == 0 From 537c7fd57970d9891765bb92f2f3fc823fa7d4a9 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 31 Oct 2022 15:39:04 +0100 Subject: [PATCH 14/32] Start RE --- src/bluesearch/k8s/ner.py | 116 +++++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 3 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 47452b58a..bfd502ae2 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -21,6 +21,7 @@ import os import time from datetime import datetime +from itertools import product from multiprocessing import Pool from typing import Any @@ -135,14 +136,123 @@ def run( progress.close() -def run_ner_model_remote( +def prepare_text_for_re( + text: str, + subj: dict, + obj: dict, + subject_symbols: tuple[str, str] = ("[[ ", " ]]"), + object_symbols: tuple[str, str] = ("<< ", " >>"), +) -> str: + """Add the subj and obj annotation to the text.""" + if subj["start"] < obj["start"]: + first, second = subj, obj + first_symbols, second_symbols = subject_symbols, object_symbols + else: + first, second = obj, subj + first_symbols, second_symbols = object_symbols, subject_symbols + + attribute = "word" + + part_1 = text[: first["start"]] + part_2 = f"{first_symbols[0]}{first[attribute]}{first_symbols[1]}" + part_3 = text[first["end"]: second["start"]] + part_4 = f"{second_symbols[0]}{second[attribute]}{second_symbols[1]}" + part_5 = text[second["end"]:] + + out = part_1 + part_2 + part_3 + part_4 + part_5 + + return out + + +def run_re_model_remote( hit: dict[str, Any], url: str, - ner_method: str, index: str | None = None, version: str | None = None, - client: elasticsearch.Elasticsearch | None = None, ) -> list[dict[str, Any]] | None: + """Perform RE on a paragraph using a remote server. + + Parameters + ---------- + hit + Elasticsearch hit. + url + URL of the Relation Extraction (RE) server. + index + Name of the ES index. + version + Version of the Relation Extraction pipeline. + """ + + url = "http://" + url + "/predict" + + matrix: list[tuple[str, str]] = [ + ("BRAIN_REGION", "ORGANISM"), + ("CELL_COMPARTMENT", "CELL_TYPE"), + ("CELL_TYPE", "BRAIN_REGION"), + ("CELL_TYPE", "ORGANISM"), + ("GENE", "BRAIN_REGION"), + ("GENE", "CELL_COMPARTMENT"), + ("GENE", "CELL_TYPE"), + ("GENE", "ORGANISM"), + ] + + text = hit["_source"]["text"].encode("utf-8") + ner_ml = hit["_source"]["ner_ml_json_v2"] + ner_ruler = hit["_source"]["ner_ruler_json_v2"] + + results_cleaned = handle_conflicts(ner_ml.extend(ner_ruler)) + + out = [] + + for subj, obj in product(results_cleaned, results_cleaned): + if subj == obj: + continue + if (subj["entity"], obj["entity"]) in matrix: + text_processed = prepare_text_for_re(text, subj, obj) + + response = requests.post( + url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + data=text_processed, + ) + + if not response.status_code == 200: + raise ValueError("Error in the request") + + result = response.json() + row = {} + if result: + row["label"] = result[0]["label"] + row["score"] = result[0]["score"] + row["subject_entity"] = subj["entity"] + row["subject_word"] = subj["word"] + row["subject_start"] = subj["start"] + row["subject_end"] = subj["end"] + row["subject_source"] = subj["source"] + row["object_entity"] = obj["entity"] + row["object_word"] = obj["word"] + row["object_start"] = obj["start"] + row["object_end"] = obj["end"] + row["object_source"] = obj["source"] + row["source"] = "ml" + out.append(row) + + # update the RE field in the document + client.update(index=index, doc={f"re_ml": out}, id=hit["_id"]) + # update the version of the Relation Extraction + client.update( + index=index, doc={f"re_ml_version": version}, id=hit["_id"] + ) + + +def run_ner_model_remote( + hit: dict[str, Any], + url: str, + ner_method: str, + index: str, + version: str, +) -> None: """Perform NER on a paragraph using a remote server. Parameters From ea7e9b28c98d9241741e663dcaf8d53ebe11ae93 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 31 Oct 2022 15:42:00 +0100 Subject: [PATCH 15/32] Correct changes mistakes --- src/bluesearch/k8s/ner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index bfd502ae2..dae21479d 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -167,9 +167,9 @@ def prepare_text_for_re( def run_re_model_remote( hit: dict[str, Any], url: str, - index: str | None = None, - version: str | None = None, -) -> list[dict[str, Any]] | None: + index: str, + version: str, +) -> None: """Perform RE on a paragraph using a remote server. Parameters @@ -250,9 +250,10 @@ def run_ner_model_remote( hit: dict[str, Any], url: str, ner_method: str, - index: str, - version: str, -) -> None: + index: str | None = None, + version: str | None = None, + client: elasticsearch.Elasticsearch | None = None, +) -> list[dict[str, Any]] | None: """Perform NER on a paragraph using a remote server. Parameters From 540e3f24854693d6737eca8fbdad864aee50ed6b Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 16:17:37 +0100 Subject: [PATCH 16/32] add test run main ner --- src/bluesearch/k8s/create_indices.py | 3 + tests/unit/k8s/test_ner.py | 91 +++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py index 5b0931fe0..2971378a5 100644 --- a/src/bluesearch/k8s/create_indices.py +++ b/src/bluesearch/k8s/create_indices.py @@ -53,8 +53,11 @@ "paragraph_id": {"type": "short"}, "text": {"type": "text"}, "ner_ml": {"type": "flattened"}, + "ner_ml_version": {"type": "keyword"}, "ner_ruler": {"type": "flattened"}, + "ner_ruler_version": {"type": "keyword"}, "re": {"type": "flattened"}, + "re_version": {"type": "keyword"}, "is_bad": {"type": "boolean"}, "embedding": { "type": "dense_vector", diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 3bc4629a7..3ddb05fdc 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -1,23 +1,13 @@ import pytest import responses -from bluesearch.k8s.create_indices import add_index -from bluesearch.k8s.ner import run_ner_model_remote - - -@responses.activate -def test_run_ner_model_remote(get_es_client): - client = get_es_client - - if client is None: - pytest.skip("Elastic search is not available") - - index = "test_index" - add_index(client, index) +from bluesearch.k8s.create_indices import add_index, remove_index +from bluesearch.k8s.ner import run, run_ner_model_remote +@pytest.fixture() +def model_response(): url = "fake_url" expected_url = "http://" + url + "/predict" - hit = {"_source": {"text": "There is a cat and a mouse in the house."}} expected_response = [ { "entity_group": "ORGANISM", @@ -42,6 +32,11 @@ def test_run_ner_model_remote(get_es_client): json=expected_response, ) +@responses.activate +def test_run_ner_model_remote(model_response): + url = "fake_url" + hit = {"_source": {"text": "There is a cat and a mouse in the house."}} + out = run_ner_model_remote(hit, url, ner_method="ml") assert isinstance(out, list) assert len(out) == 2 @@ -58,3 +53,71 @@ def test_run_ner_model_remote(get_es_client): assert isinstance(out, list) assert out[0]["score"] == 0 assert out[1]["score"] == 0 + +@responses.activate +def test_run(get_es_client, model_response): + client = get_es_client + + if client is None: + pytest.skip("Elastic search is not available") + + index = "test_index" + add_index(client, index) + + fake_data = [ + { + "article_id": "1", + "paragraph_id": "1", + "text": "There is a cat and a mouse in the house.", + }, + { + "article_id": "1", + "paragraph_id": "2", + "text": "There is a cat and a mouse in the house.", + }, + { + "article_id": "2", + "paragraph_id": "1", + "text": "There is a cat and a mouse in the house.", + } + ] + + for fd in fake_data: + client.update( + index=index, doc=fd, id=fd["paragraph_id"] + ) + + run(client, "v1", index=index) + + # check that the results are in the database + query = {"bool": {"must": {"term": {"field": {"ner_ml_version": "v1"}}}}} + paragraph_count = client.count(index=index, query=query)["count"] + assert paragraph_count == 3 + + # check that the results are correct + for fd in fake_data: + res = client.get(index=index, id=fd["paragraph_id"]) + assert res["_source"]["ner_ml"][0]["entity"] == "cat" + assert res["_source"]["ner_ml"][0]["score"] == 0.9439833760261536 + assert res["_source"]["ner_ml"][0]["entity_type"] == "ORGANISM" + assert res["_source"]["ner_ml"][0]["start"] == 11 + assert res["_source"]["ner_ml"][0]["end"] == 14 + assert res["_source"]["ner_ml"][0]["source"] == "ml" + + assert res["_source"]["ner_ml"][1]["entity"] == "mouse" + assert res["_source"]["ner_ml"][1]["score"] == 0.9975798726081848 + assert res["_source"]["ner_ml"][1]["entity_type"] == "ORGANISM" + assert res["_source"]["ner_ml"][1]["start"] == 21 + assert res["_source"]["ner_ml"][1]["end"] == 26 + assert res["_source"]["ner_ml"][1]["source"] == "ml" + + assert res["_source"]["ner_ml_version"] == "v1" + + assert res["_source"]["ner_ruler_version"] is None + + # check that all paragraphs have been updated + query = {"bool": {"must_not": {"term": {"field": {"ner_ml_version": "v1"}}}}} + paragraph_count = client.count(index=index, query=query)["count"] + assert paragraph_count == 0 + + remove_index(client, index) From 6f8f2983bd00ad75b91ec8b0c946bb88213f2c96 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 31 Oct 2022 16:42:17 +0100 Subject: [PATCH 17/32] Remove RE part --- src/bluesearch/k8s/ner.py | 110 -------------------------------------- 1 file changed, 110 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index dae21479d..7374556bd 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -136,116 +136,6 @@ def run( progress.close() -def prepare_text_for_re( - text: str, - subj: dict, - obj: dict, - subject_symbols: tuple[str, str] = ("[[ ", " ]]"), - object_symbols: tuple[str, str] = ("<< ", " >>"), -) -> str: - """Add the subj and obj annotation to the text.""" - if subj["start"] < obj["start"]: - first, second = subj, obj - first_symbols, second_symbols = subject_symbols, object_symbols - else: - first, second = obj, subj - first_symbols, second_symbols = object_symbols, subject_symbols - - attribute = "word" - - part_1 = text[: first["start"]] - part_2 = f"{first_symbols[0]}{first[attribute]}{first_symbols[1]}" - part_3 = text[first["end"]: second["start"]] - part_4 = f"{second_symbols[0]}{second[attribute]}{second_symbols[1]}" - part_5 = text[second["end"]:] - - out = part_1 + part_2 + part_3 + part_4 + part_5 - - return out - - -def run_re_model_remote( - hit: dict[str, Any], - url: str, - index: str, - version: str, -) -> None: - """Perform RE on a paragraph using a remote server. - - Parameters - ---------- - hit - Elasticsearch hit. - url - URL of the Relation Extraction (RE) server. - index - Name of the ES index. - version - Version of the Relation Extraction pipeline. - """ - - url = "http://" + url + "/predict" - - matrix: list[tuple[str, str]] = [ - ("BRAIN_REGION", "ORGANISM"), - ("CELL_COMPARTMENT", "CELL_TYPE"), - ("CELL_TYPE", "BRAIN_REGION"), - ("CELL_TYPE", "ORGANISM"), - ("GENE", "BRAIN_REGION"), - ("GENE", "CELL_COMPARTMENT"), - ("GENE", "CELL_TYPE"), - ("GENE", "ORGANISM"), - ] - - text = hit["_source"]["text"].encode("utf-8") - ner_ml = hit["_source"]["ner_ml_json_v2"] - ner_ruler = hit["_source"]["ner_ruler_json_v2"] - - results_cleaned = handle_conflicts(ner_ml.extend(ner_ruler)) - - out = [] - - for subj, obj in product(results_cleaned, results_cleaned): - if subj == obj: - continue - if (subj["entity"], obj["entity"]) in matrix: - text_processed = prepare_text_for_re(text, subj, obj) - - response = requests.post( - url, - headers={"accept": "application/json", "Content-Type": "text/plain"}, - data=text_processed, - ) - - if not response.status_code == 200: - raise ValueError("Error in the request") - - result = response.json() - row = {} - if result: - row["label"] = result[0]["label"] - row["score"] = result[0]["score"] - row["subject_entity"] = subj["entity"] - row["subject_word"] = subj["word"] - row["subject_start"] = subj["start"] - row["subject_end"] = subj["end"] - row["subject_source"] = subj["source"] - row["object_entity"] = obj["entity"] - row["object_word"] = obj["word"] - row["object_start"] = obj["start"] - row["object_end"] = obj["end"] - row["object_source"] = obj["source"] - row["source"] = "ml" - out.append(row) - - # update the RE field in the document - client.update(index=index, doc={f"re_ml": out}, id=hit["_id"]) - # update the version of the Relation Extraction - client.update( - index=index, doc={f"re_ml_version": version}, id=hit["_id"] - ) - - def run_ner_model_remote( hit: dict[str, Any], url: str, From 47b99239cd4be409b03c096ba2a06e19b135ac9b Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 31 Oct 2022 16:49:19 +0100 Subject: [PATCH 18/32] remove unused import --- src/bluesearch/k8s/ner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 7374556bd..47452b58a 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -21,7 +21,6 @@ import os import time from datetime import datetime -from itertools import product from multiprocessing import Pool from typing import Any From 9da3e4b78124d881e734e566f8c09652dc26978c Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 18:16:05 +0100 Subject: [PATCH 19/32] linter --- tests/unit/k8s/test_ner.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 3ddb05fdc..a3d992fd0 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -4,6 +4,7 @@ from bluesearch.k8s.create_indices import add_index, remove_index from bluesearch.k8s.ner import run, run_ner_model_remote + @pytest.fixture() def model_response(): url = "fake_url" @@ -32,6 +33,7 @@ def model_response(): json=expected_response, ) + @responses.activate def test_run_ner_model_remote(model_response): url = "fake_url" @@ -54,6 +56,7 @@ def test_run_ner_model_remote(model_response): assert out[0]["score"] == 0 assert out[1]["score"] == 0 + @responses.activate def test_run(get_es_client, model_response): client = get_es_client @@ -79,14 +82,12 @@ def test_run(get_es_client, model_response): "article_id": "2", "paragraph_id": "1", "text": "There is a cat and a mouse in the house.", - } + }, ] for fd in fake_data: - client.update( - index=index, doc=fd, id=fd["paragraph_id"] - ) - + client.update(index=index, doc=fd, id=fd["paragraph_id"]) + run(client, "v1", index=index) # check that the results are in the database From be00832bb268251f72766f171e03a1ce13b0778e Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 18:29:13 +0100 Subject: [PATCH 20/32] update test --- tests/unit/k8s/test_ner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index a3d992fd0..7fad67787 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -88,7 +88,7 @@ def test_run(get_es_client, model_response): for fd in fake_data: client.update(index=index, doc=fd, id=fd["paragraph_id"]) - run(client, "v1", index=index) + run(client, "v1", index=index, run_async=False) # check that the results are in the database query = {"bool": {"must": {"term": {"field": {"ner_ml_version": "v1"}}}}} From 2ffff51984b27742eb4792b6a96672a900ed0986 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Mon, 31 Oct 2022 19:24:53 +0100 Subject: [PATCH 21/32] fix test run ner --- src/bluesearch/k8s/ner.py | 11 ++++---- tests/unit/k8s/test_ner.py | 57 +++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 47452b58a..7168cc71f 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -124,11 +124,12 @@ def run( else: for hit in scan(client, query={"query": query}, index=index, scroll="12h"): run_ner_model_remote( - hit, - url, - ner_method, - index, - version, + hit=hit, + url=url, + ner_method=ner_method, + index=index, + version=version, + client=client, ) progress.update(1) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 7fad67787..9b0ac75d4 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -1,7 +1,7 @@ import pytest import responses -from bluesearch.k8s.create_indices import add_index, remove_index +from bluesearch.k8s.create_indices import MAPPINGS_PARAGRAPHS, add_index, remove_index from bluesearch.k8s.ner import run, run_ner_model_remote @@ -33,6 +33,13 @@ def model_response(): json=expected_response, ) + responses.add( + responses.POST, + expected_url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + json=expected_response, + ) + @responses.activate def test_run_ner_model_remote(model_response): @@ -58,14 +65,21 @@ def test_run_ner_model_remote(model_response): @responses.activate -def test_run(get_es_client, model_response): +def test_run(monkeypatch, get_es_client, model_response): + + BENTOML_NER_ML_URL = "fake_url" + BENTOML_NER_RULER_URL = "fake_url" + + monkeypatch.setenv("BENTOML_NER_ML_URL", BENTOML_NER_ML_URL) + monkeypatch.setenv("BENTOML_NER_RULER_URL", BENTOML_NER_RULER_URL) + client = get_es_client if client is None: pytest.skip("Elastic search is not available") index = "test_index" - add_index(client, index) + add_index(client, index, mappings=MAPPINGS_PARAGRAPHS) fake_data = [ { @@ -80,44 +94,43 @@ def test_run(get_es_client, model_response): }, { "article_id": "2", - "paragraph_id": "1", + "paragraph_id": "3", "text": "There is a cat and a mouse in the house.", }, ] for fd in fake_data: - client.update(index=index, doc=fd, id=fd["paragraph_id"]) + client.index(index=index, document=fd, id=fd["paragraph_id"]) run(client, "v1", index=index, run_async=False) + client.indices.refresh(index=index) # check that the results are in the database - query = {"bool": {"must": {"term": {"field": {"ner_ml_version": "v1"}}}}} + query = {"bool": {"must": {"term": {"ner_ml_version": "v1"}}}} paragraph_count = client.count(index=index, query=query)["count"] assert paragraph_count == 3 # check that the results are correct for fd in fake_data: res = client.get(index=index, id=fd["paragraph_id"]) - assert res["_source"]["ner_ml"][0]["entity"] == "cat" - assert res["_source"]["ner_ml"][0]["score"] == 0.9439833760261536 - assert res["_source"]["ner_ml"][0]["entity_type"] == "ORGANISM" - assert res["_source"]["ner_ml"][0]["start"] == 11 - assert res["_source"]["ner_ml"][0]["end"] == 14 - assert res["_source"]["ner_ml"][0]["source"] == "ml" - - assert res["_source"]["ner_ml"][1]["entity"] == "mouse" - assert res["_source"]["ner_ml"][1]["score"] == 0.9975798726081848 - assert res["_source"]["ner_ml"][1]["entity_type"] == "ORGANISM" - assert res["_source"]["ner_ml"][1]["start"] == 21 - assert res["_source"]["ner_ml"][1]["end"] == 26 - assert res["_source"]["ner_ml"][1]["source"] == "ml" + assert res["_source"]["ner_ml_json_v2"][0]["entity"] == "cat" + assert res["_source"]["ner_ml_json_v2"][0]["score"] == 0.9439833760261536 + assert res["_source"]["ner_ml_json_v2"][0]["entity_type"] == "ORGANISM" + assert res["_source"]["ner_ml_json_v2"][0]["start"] == 11 + assert res["_source"]["ner_ml_json_v2"][0]["end"] == 14 + assert res["_source"]["ner_ml_json_v2"][0]["source"] == "ml" + + assert res["_source"]["ner_ml_json_v2"][1]["entity"] == "mouse" + assert res["_source"]["ner_ml_json_v2"][1]["score"] == 0.9975798726081848 + assert res["_source"]["ner_ml_json_v2"][1]["entity_type"] == "ORGANISM" + assert res["_source"]["ner_ml_json_v2"][1]["start"] == 21 + assert res["_source"]["ner_ml_json_v2"][1]["end"] == 26 + assert res["_source"]["ner_ml_json_v2"][1]["source"] == "ml" assert res["_source"]["ner_ml_version"] == "v1" - assert res["_source"]["ner_ruler_version"] is None - # check that all paragraphs have been updated - query = {"bool": {"must_not": {"term": {"field": {"ner_ml_version": "v1"}}}}} + query = {"bool": {"must_not": {"term": {"ner_ml_version": "v1"}}}} paragraph_count = client.count(index=index, query=query)["count"] assert paragraph_count == 0 From 271fa446b30adf670f79fda435b60336c701ee16 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Tue, 1 Nov 2022 10:05:22 +0100 Subject: [PATCH 22/32] Add time.sleep to see if test is passing --- tests/unit/k8s/test_ner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 9b0ac75d4..6a4f77eca 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -1,5 +1,6 @@ import pytest import responses +import time from bluesearch.k8s.create_indices import MAPPINGS_PARAGRAPHS, add_index, remove_index from bluesearch.k8s.ner import run, run_ner_model_remote @@ -105,6 +106,8 @@ def test_run(monkeypatch, get_es_client, model_response): run(client, "v1", index=index, run_async=False) client.indices.refresh(index=index) + time.sleep(60) + # check that the results are in the database query = {"bool": {"must": {"term": {"ner_ml_version": "v1"}}}} paragraph_count = client.count(index=index, query=query)["count"] From 365b950f035f824d8941d782e73c0fec8c47fd08 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Tue, 1 Nov 2022 10:16:11 +0100 Subject: [PATCH 23/32] Change the place of time.sleep --- tests/unit/k8s/test_ner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 6a4f77eca..543439976 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -106,11 +106,10 @@ def test_run(monkeypatch, get_es_client, model_response): run(client, "v1", index=index, run_async=False) client.indices.refresh(index=index) - time.sleep(60) - # check that the results are in the database query = {"bool": {"must": {"term": {"ner_ml_version": "v1"}}}} paragraph_count = client.count(index=index, query=query)["count"] + time.sleep(60) assert paragraph_count == 3 # check that the results are correct From 6ac61807afa15749670c889507880860bb4994f3 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Tue, 1 Nov 2022 10:32:46 +0100 Subject: [PATCH 24/32] correct test index name --- tests/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 43d0f97ba..12c65b203 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -518,5 +518,7 @@ def get_es_client(monkeypatch): if client is not None: for index in client.indices.get_alias().keys(): - if index in ["articles", "paragraphs", "test_index"]: + if index in ["test_articles", "test_paragraphs", "test_index"]: remove_index(client, index) + time.sleep(1) + From 812e76503e694f87cd5b432a08e20d9de0fe42bc Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Tue, 1 Nov 2022 11:56:46 +0100 Subject: [PATCH 25/32] Add refresh before running run --- tests/unit/k8s/test_ner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 543439976..b922e1c76 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -1,6 +1,5 @@ import pytest import responses -import time from bluesearch.k8s.create_indices import MAPPINGS_PARAGRAPHS, add_index, remove_index from bluesearch.k8s.ner import run, run_ner_model_remote @@ -103,13 +102,13 @@ def test_run(monkeypatch, get_es_client, model_response): for fd in fake_data: client.index(index=index, document=fd, id=fd["paragraph_id"]) + client.indices.refresh(index=index) run(client, "v1", index=index, run_async=False) client.indices.refresh(index=index) # check that the results are in the database query = {"bool": {"must": {"term": {"ner_ml_version": "v1"}}}} paragraph_count = client.count(index=index, query=query)["count"] - time.sleep(60) assert paragraph_count == 3 # check that the results are correct From c60a9feae898106e45eaddac3bc9ac4ab9faab33 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Tue, 1 Nov 2022 13:17:45 +0100 Subject: [PATCH 26/32] Change the mapping --- src/bluesearch/k8s/create_indices.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py index 2971378a5..ffa4f7c54 100644 --- a/src/bluesearch/k8s/create_indices.py +++ b/src/bluesearch/k8s/create_indices.py @@ -52,9 +52,9 @@ "section_name": {"type": "keyword"}, "paragraph_id": {"type": "short"}, "text": {"type": "text"}, - "ner_ml": {"type": "flattened"}, + "ner_ml_json_v2": {"type": "flattened"}, "ner_ml_version": {"type": "keyword"}, - "ner_ruler": {"type": "flattened"}, + "ner_ruler_json_v2": {"type": "flattened"}, "ner_ruler_version": {"type": "keyword"}, "re": {"type": "flattened"}, "re_version": {"type": "keyword"}, From e340b362f7a68f0751c9b1b1432a5a47a9dd7ce2 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Tue, 1 Nov 2022 13:20:13 +0100 Subject: [PATCH 27/32] Remove time.sleep in get_es_client --- tests/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 12c65b203..a3cac711a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -520,5 +520,3 @@ def get_es_client(monkeypatch): for index in client.indices.get_alias().keys(): if index in ["test_articles", "test_paragraphs", "test_index"]: remove_index(client, index) - time.sleep(1) - From b4f6746fcf6adf08c083a83b7a3a9ea6bc820402 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 2 Nov 2022 10:10:06 +0100 Subject: [PATCH 28/32] small fix handle conflicts ner --- src/bluesearch/k8s/ner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bluesearch/k8s/ner.py b/src/bluesearch/k8s/ner.py index 7168cc71f..959593197 100644 --- a/src/bluesearch/k8s/ner.py +++ b/src/bluesearch/k8s/ner.py @@ -236,8 +236,8 @@ def handle_conflicts(results_paragraph: list[dict]) -> list[dict]: array = np.zeros(max([x["end"] for x in temp])) for res in temp: - add_one = 1 if res["word"][0] == " " else 0 - sub_one = 1 if res["word"][-1] == " " else 0 + add_one = 1 if res["entity"][0] == " " else 0 + sub_one = 1 if res["entity"][-1] == " " else 0 if len(results_cleaned) == 0: results_cleaned.append(res) array[res["start"] + add_one : res["end"] - sub_one] = 1 From b99bfb041d8ee1b72ada343a260ac20269539e86 Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 2 Nov 2022 10:13:08 +0100 Subject: [PATCH 29/32] add test handle conflicts ner --- tests/unit/k8s/test_ner.py | 364 ++++++++++++++++++++++++++++++++++++- 1 file changed, 363 insertions(+), 1 deletion(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index b922e1c76..106e8ac0b 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -2,7 +2,7 @@ import responses from bluesearch.k8s.create_indices import MAPPINGS_PARAGRAPHS, add_index, remove_index -from bluesearch.k8s.ner import run, run_ner_model_remote +from bluesearch.k8s.ner import run, run_ner_model_remote, handle_conflicts @pytest.fixture() @@ -136,3 +136,365 @@ def test_run(monkeypatch, get_es_client, model_response): assert paragraph_count == 0 remove_index(client, index) + +@pytest.mark.parametrize( + ("raw_ents", "cleaned_ents"), + [ + pytest.param([], [], id="empty list"), + pytest.param( + [ + { + "start": 1, + "end": 5, + "word": "word", + }, + ], + [ + { + "start": 1, + "end": 5, + "word": "word", + }, + ], + id="one element", + ), + pytest.param( + [ + { + "start": 1, + "end": 5, + "word": "word", + "source": "RULES", + }, + { + "start": 10, + "end": 15, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 5, + "word": "word", + "source": "RULES", + }, + { + "start": 10, + "end": 15, + "word": "word", + "source": "ML", + }, + ], + id="no overlap", + ), + pytest.param( + [ + { + "start": 1, + "end": 5, + "word": "word", + "source": "RULES", + }, + { + "start": 1, + "end": 5, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 5, + "word": "word", + "source": "ML", + }, + ], + id="perfect overlap", + ), + pytest.param( + [ + { + "start": 1, + "end": 5, + "word": "word", + "source": "RULES", + }, + { + "start": 2, + "end": 20, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 2, + "end": 20, + "word": "word", + "source": "ML", + }, + ], + id="overlap - ML longer", + ), + pytest.param( + [ + { + "start": 1, + "end": 50, + "word": "word", + "source": "RULES", + }, + { + "start": 25, + "end": 60, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 50, + "word": "word", + "source": "RULES", + }, + ], + id="overlap - RULES longer", + ), + pytest.param( + [ + { + "start": 1, + "end": 50, + "word": "word", + "source": "RULES", + }, + { + "start": 25, + "end": 40, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 1, + "end": 50, + "word": "word", + "source": "RULES", + }, + ], + id="overlap - ML subset of RULES", + ), + pytest.param( + [ + { + "start": 4, + "end": 24, + "word": "word", + "source": "RULES", + }, + { + "start": 2, + "end": 40, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 2, + "end": 40, + "word": "word", + "source": "ML", + }, + ], + id="overlap - RULES subset of ML", + ), + pytest.param( + [ + { + "start": 10, + "end": 30, + "word": "word", + "source": "RULES", + }, + { + "start": 20, + "end": 40, + "word": "word", + "source": "ML", + }, + ], + [ + { + "start": 20, + "end": 40, + "word": "word", + "source": "ML", + }, + ], + id="overlap - same length", + ), + pytest.param( + [ + { + "start": 10, + "end": 30, + "word": "word", + "source": "RULES", + }, + { + "start": 15, + "end": 20, + "word": "word", + "source": "ML", + }, + { + "start": 23, + "end": 34, + "word": "word", + "source": "ML", + }, + { + "start": 31, + "end": 33, + "word": "word", + "source": "RULES", + }, + ], + [ + { + "start": 10, + "end": 30, + "word": "word", + "source": "RULES", + }, + { + "start": 31, + "end": 33, + "word": "word", + "source": "RULES", + }, + ], + id="more entries - 1", + ), + pytest.param( + [ + { + "start": 10, + "end": 30, + "word": "word", + "source": "RULES", + }, + { + "start": 20, + "end": 50, + "word": "word", + "source": "ML", + }, + { + "start": 35, + "end": 100, + "word": "word", + "source": "RULES", + }, + ], + [ + { + "start": 10, + "end": 30, + "word": "word", + "source": "RULES", + }, + { + "start": 35, + "end": 100, + "word": "word", + "source": "RULES", + }, + ], + id="more entries - 2", + ), + pytest.param( + [ + { + "start": 10, + "end": 12, + "word": "word", + "source": "RULES", + }, + { + "start": 10, + "end": 12, + "word": "word", + "source": "ML", + }, + { + "start": 35, + "end": 100, + "word": "word", + "source": "RULES", + }, + ], + [ + { + "start": 10, + "end": 12, + "word": "word", + "source": "ML", + }, + { + "start": 35, + "end": 100, + "word": "word", + "source": "RULES", + }, + ], + id="entries with only 2 chars", + ), + pytest.param( + [ + { + "start": 10, + "end": 12, + "word": "word", + "source": "RULES", + }, + { + "start": 10, + "end": 12, + "word": "word", + "source": "ML", + }, + { + "start": 12, + "end": 15, + "word": " word", + "source": "RULES", + }, + ], + [ + { + "start": 10, + "end": 12, + "word": "word", + "source": "ML", + }, + { + "start": 12, + "end": 15, + "word": " word", + "source": "RULES", + }, + ], + id="overlap with whitespace", + ), + ], +) +def test_handle_conflicts(raw_ents, cleaned_ents): + """Test handle_conflicts function.""" + assert cleaned_ents == handle_conflicts(raw_ents) \ No newline at end of file From 2fa94e1c90088f957529bc792ced9e39a06c9a0a Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 2 Nov 2022 10:13:47 +0100 Subject: [PATCH 30/32] small fix lint --- tests/unit/k8s/test_ner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 106e8ac0b..03b69691f 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -137,6 +137,7 @@ def test_run(monkeypatch, get_es_client, model_response): remove_index(client, index) + @pytest.mark.parametrize( ("raw_ents", "cleaned_ents"), [ @@ -497,4 +498,4 @@ def test_run(monkeypatch, get_es_client, model_response): ) def test_handle_conflicts(raw_ents, cleaned_ents): """Test handle_conflicts function.""" - assert cleaned_ents == handle_conflicts(raw_ents) \ No newline at end of file + assert cleaned_ents == handle_conflicts(raw_ents) From 7a1cc4b007f13ee79b422114b6d761ac26617f7b Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 2 Nov 2022 10:28:06 +0100 Subject: [PATCH 31/32] fix isort --- tests/unit/k8s/test_ner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index 03b69691f..d4fd816c7 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -2,7 +2,7 @@ import responses from bluesearch.k8s.create_indices import MAPPINGS_PARAGRAPHS, add_index, remove_index -from bluesearch.k8s.ner import run, run_ner_model_remote, handle_conflicts +from bluesearch.k8s.ner import handle_conflicts, run, run_ner_model_remote @pytest.fixture() From dc26ad6f569329abd45e4a61d3d21ba79a8546fa Mon Sep 17 00:00:00 2001 From: Diogo Santos Date: Wed, 2 Nov 2022 10:30:00 +0100 Subject: [PATCH 32/32] update test --- tests/unit/k8s/test_ner.py | 90 +++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/unit/k8s/test_ner.py b/tests/unit/k8s/test_ner.py index d4fd816c7..db05c78ac 100644 --- a/tests/unit/k8s/test_ner.py +++ b/tests/unit/k8s/test_ner.py @@ -147,14 +147,14 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 5, - "word": "word", + "entity": "entity", }, ], [ { "start": 1, "end": 5, - "word": "word", + "entity": "entity", }, ], id="one element", @@ -164,13 +164,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 5, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 10, "end": 15, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -178,13 +178,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 5, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 10, "end": 15, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -195,13 +195,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 5, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 1, "end": 5, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -209,7 +209,7 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 5, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -220,13 +220,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 5, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 2, "end": 20, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -234,7 +234,7 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 2, "end": 20, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -245,13 +245,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 50, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 25, "end": 60, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -259,7 +259,7 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 50, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -270,13 +270,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 50, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 25, "end": 40, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -284,7 +284,7 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 1, "end": 50, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -295,13 +295,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 4, "end": 24, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 2, "end": 40, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -309,7 +309,7 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 2, "end": 40, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -320,13 +320,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 30, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 20, "end": 40, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -334,7 +334,7 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 20, "end": 40, - "word": "word", + "entity": "entity", "source": "ML", }, ], @@ -345,25 +345,25 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 30, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 15, "end": 20, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 23, "end": 34, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 31, "end": 33, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -371,13 +371,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 30, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 31, "end": 33, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -388,19 +388,19 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 30, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 20, "end": 50, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 35, "end": 100, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -408,13 +408,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 30, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 35, "end": 100, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -425,19 +425,19 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 12, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 10, "end": 12, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 35, "end": 100, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -445,13 +445,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 12, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 35, "end": 100, - "word": "word", + "entity": "entity", "source": "RULES", }, ], @@ -462,19 +462,19 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 12, - "word": "word", + "entity": "entity", "source": "RULES", }, { "start": 10, "end": 12, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 12, "end": 15, - "word": " word", + "entity": " entity", "source": "RULES", }, ], @@ -482,13 +482,13 @@ def test_run(monkeypatch, get_es_client, model_response): { "start": 10, "end": 12, - "word": "word", + "entity": "entity", "source": "ML", }, { "start": 12, "end": 15, - "word": " word", + "entity": " entity", "source": "RULES", }, ],