diff --git a/src/bluesearch/entrypoint/database/add_es.py b/src/bluesearch/entrypoint/database/add_es.py index 2841d6cdb..ace942f68 100644 --- a/src/bluesearch/entrypoint/database/add_es.py +++ b/src/bluesearch/entrypoint/database/add_es.py @@ -115,12 +115,25 @@ def bulk_paragraphs( for inp in inputs: serialized = inp.read_text("utf-8") article = Article.from_json(serialized) + # add abstract to paragraphs in order to be able to search for abstracts + for i, abstract in enumerate(article.abstract): + doc = { + "_index": "paragraphs", + "_source": { + "article_id": article.uid, + "section": "abstract", + "text": abstract, + "paragraph_id": i, + }, + } + yield doc + # add body paragraphs for ppos, (section, text) in enumerate(article.section_paragraphs): doc = { "_index": "paragraphs", "_source": { "article_id": article.uid, - "section_name": section, + "section": section, "text": text, "paragraph_id": ppos, }, @@ -160,7 +173,7 @@ def run( if resp[0] == 0: raise RuntimeWarning(f"No articles were loaded to ES from '{parsed_path}'!") - logger.info("Uploading articles to the database...") + logger.info("Uploading paragraphs to the database...") progress = tqdm.tqdm( desc="Uploading paragraphs", total=len(inputs), unit="articles" ) diff --git a/src/bluesearch/k8s/create_indices.py b/src/bluesearch/k8s/create_indices.py index 00db98efe..3b195f04a 100644 --- a/src/bluesearch/k8s/create_indices.py +++ b/src/bluesearch/k8s/create_indices.py @@ -49,7 +49,7 @@ "dynamic": "strict", "properties": { "article_id": {"type": "keyword"}, - "section_name": {"type": "keyword"}, + "section": {"type": "keyword"}, "paragraph_id": {"type": "short"}, "text": {"type": "text"}, "is_bad": {"type": "boolean"}, diff --git a/tests/unit/entrypoint/database/test_add_es.py b/tests/unit/entrypoint/database/test_add_es.py index cb44b4adb..84e249757 100644 --- a/tests/unit/entrypoint/database/test_add_es.py +++ b/tests/unit/entrypoint/database/test_add_es.py @@ -72,7 +72,7 @@ def test(get_es_client: Elasticsearch, tmp_path: Path) -> None: # verify paragraphs resp = client.search(index="paragraphs", query={"match_all": {}}) - assert resp["hits"]["total"]["value"] == 4 + assert resp["hits"]["total"]["value"] == 8 all_docs = set() for doc in resp["hits"]["hits"]: @@ -80,14 +80,18 @@ def test(get_es_client: Elasticsearch, tmp_path: Path) -> None: ( doc["_source"]["article_id"], doc["_source"]["paragraph_id"], - doc["_source"]["section_name"], + doc["_source"]["section"], doc["_source"]["text"], ) ) all_docs_expected = { + ("1", 0, "abstract", "some test abstract"), + ("1", 1, "abstract", "abcd"), ("1", 0, "intro", "some test section_paragraphs 1client"), ("1", 1, "summary", "some test section_paragraphs 2"), + ("2", 0, "abstract", "dsaklf"), + ("2", 1, "abstract", "abcd"), ("2", 0, "intro", "some TESTTT section_paragraphs 1client"), ("2", 1, "summary", "some other test section_paragraphs 2"), }