diff --git a/backend/requirements.txt b/backend/requirements.txt index c0f4d79..7e19f86 100644 Binary files a/backend/requirements.txt and b/backend/requirements.txt differ diff --git a/backend/src/app.py b/backend/src/app.py index 0632587..98ffba6 100644 --- a/backend/src/app.py +++ b/backend/src/app.py @@ -4,42 +4,12 @@ from docarray import DocumentArray from docarray.document.generators import from_csv -from backend_config import papers_data_path, papers_data_url +from config import config, create_parser from flows import index_flow, search_flow from helpers import download_csv, log, maximise_csv_field_size_limit - -# boolean args: -# https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse/36031646 -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ("yes", "true", "t", "y", "1"): - return True - elif v.lower() in ("no", "false", "f", "n", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - -def get_args(): - # Command line arguments definitions - parser = argparse.ArgumentParser() - parser.add_argument( - "--index", - dest="index", - default=False, - action="store_true", - help="index the available documents", - ) - parser.add_argument( - "--n", - type=int, - default=0, - help="when `--index` is used, specifies the number of documnts to index (0 indexes the full dataset)", - ) - - return parser.parse_args() +papers_data_path = config["papers_data_path"].get() +papers_data_url = config["papers_data_url"].get() def index(n): @@ -70,12 +40,13 @@ def index(n): indexer.index(papers, request_size=32) -args = get_args() +args = create_parser() +config.set_args(args) if args.index: index(args.n) -# running the search/finetuning flow as a service + # running the search/finetuning flow as a service flow = search_flow() flow.expose_endpoint("/finetune", summary="Finetune documents.", tags=["Finetuning"]) diff --git a/backend/src/config.py b/backend/src/config.py new file mode 100644 index 0000000..2eb5fc1 --- /dev/null +++ b/backend/src/config.py @@ -0,0 +1,30 @@ +from argparse import ArgumentParser +from pathlib import Path + +from confuse import Configuration + +config = Configuration("backend") + +file_abs_path = Path(__file__).parent.resolve() +config.set_file(file_abs_path / "config.yaml") + + +def create_parser(): + + parser = ArgumentParser() + + parser.add_argument( + "--index", + dest="index", + default=False, + action="store_true", + help="index the available documents", + ) + parser.add_argument( + "--n", + type=int, + default=0, + help="when `--index` is used, specifies the number of documnts to index (0 indexes the full dataset)", + ) + + return parser.parse_args() diff --git a/backend/src/backend_config.py b/backend/src/config.yaml similarity index 54% rename from backend/src/backend_config.py rename to backend/src/config.yaml index 0d7a5eb..b1ea3b4 100644 --- a/backend/src/backend_config.py +++ b/backend/src/config.yaml @@ -1,17 +1,18 @@ + # Hugging Face: https://huggingface.co/sentence-transformers/allenai-specter -embedding_model = "sentence-transformers/allenai-specter" +embedding_model : "sentence-transformers/allenai-specter" # dataset link -papers_data_url = "http://www.lri.fr/owncloud/index.php/s/OO987IvsoKwWI3l/download" +papers_data_url : "http://www.lri.fr/owncloud/index.php/s/OO987IvsoKwWI3l/download" # Number of search results to show. -top_k = 5 +top_k : 5 # Protein file path relative to root. -papers_data_path = "data/papers.csv" +papers_data_path : "data/papers.csv" # Prints logs to command line if true. -print_logs = True +print_logs : True # Search service port -search_port = 8020 +search_port : 8020 diff --git a/backend/src/executors.py b/backend/src/executors.py index 5db6203..54b173b 100644 --- a/backend/src/executors.py +++ b/backend/src/executors.py @@ -1,14 +1,14 @@ -import re import os -from typing import Sequence, List, Tuple -from sentence_transformers import SentenceTransformer, InputExample, losses +from jina import DocumentArray, Executor, requests +from sentence_transformers import InputExample, SentenceTransformer, losses from torch.utils.data import DataLoader -from jina import Executor, requests, Document, DocumentArray -from backend_config import top_k, embedding_model +from config import config from helpers import log +embedding_model = config["embedding_model"].get() + def get_model_dir(): model_dir = f"./models/{embedding_model}" diff --git a/backend/src/flows.py b/backend/src/flows.py index 7f9501c..fa89182 100644 --- a/backend/src/flows.py +++ b/backend/src/flows.py @@ -1,7 +1,7 @@ from jina import Flow +from config import config from executors import SpecterExecutor -from backend_config import top_k, search_port # Using a standard indexer: https://hub.jina.ai/executor/zb38xlt4 indexer = "jinahub://SimpleIndexer" @@ -21,14 +21,14 @@ def index_flow(): def search_flow(): flow = ( - Flow(port_expose=search_port, protocol="http") + Flow(port_expose=config["search_port"].get(), protocol="http") .add(uses=SpecterExecutor) .add( uses=indexer, uses_with={ "match_args": { "metric": "cosine", - "limit": top_k, + "limit": config["top_k"].get(), }, }, ) diff --git a/backend/src/helpers.py b/backend/src/helpers.py index 4ff6f09..476b5fe 100644 --- a/backend/src/helpers.py +++ b/backend/src/helpers.py @@ -1,10 +1,11 @@ import os +from argparse import ArgumentTypeError from csv import field_size_limit from sys import maxsize import requests -from backend_config import print_logs +from config import config def download_csv(url, fp): @@ -15,15 +16,27 @@ def download_csv(url, fp): def log(message): - if print_logs: + if config["print_logs"].get(): print(message) def maximise_csv_field_size_limit(maxInt=maxsize): - while True: try: field_size_limit(maxInt) break except OverflowError: maxInt = int(maxInt / 10) + + +# boolean args: +# https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse/36031646 +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise ArgumentTypeError("Boolean value expected.")