diff --git a/compose.yml b/compose.yml
index 2fd45c1..96a2a8b 100644
--- a/compose.yml
+++ b/compose.yml
@@ -338,6 +338,20 @@ services:
     volumes:
       - ./logs:/app/logs
       - ./.env:/app/.env
+  arax_pathfinder:
+    container_name: arax_pathfinder
+    build:
+      context: .
+      dockerfile: workers/arax_pathfinder/Dockerfile
+    restart: unless-stopped
+    depends_on:
+      shepherd_db:
+        condition: service_healthy
+      shepherd_broker:
+        condition: service_healthy
+    volumes:
+      - ./logs:/app/logs
+      - ./.env:/app/.env
 
   arax_rank:
     container_name: arax_rank
diff --git a/shepherd_utils/config.py b/shepherd_utils/config.py
index 21551d9..1178cf5 100644
--- a/shepherd_utils/config.py
+++ b/shepherd_utils/config.py
@@ -23,7 +23,23 @@ class Settings(BaseSettings):
     sync_kg_retrieval_url: str = "https://strider.renci.org/query"
     default_data_tier: int = 0
     omnicorp_url: str = "https://aragorn-ranker.renci.org/omnicorp_overlay"
+
+    # ARAX configs
     arax_url: str = "https://arax.ncats.io/shepherd/api/arax/v1.4/query"
+    plover_url: str = "https://kg2cploverdb.ci.transltr.io"
+    curie_ngd_addr: str = (
+        "mysql:arax-databases-mysql.rtx.ai:public_ro:curie_ngd_v1_0_kg2_10_2"
+    )
+    node_degree_addr: str = (
+        "mysql:arax-databases-mysql.rtx.ai:public_ro:kg2c_v1_0_kg2_10_2"
+    )
+    arax_biolink_version: str = "4.2.5"
+    arax_blocked_list_url: str = (
+        "https://raw.githubusercontent.com/RTXteam/RTX/master/"
+        "code/ARAX/KnowledgeSources/general_concepts.json"
+    )
+    # End of ARAX configs
+
     node_norm: str = "https://biothings.ci.transltr.io/nodenorm/api/"
 
     pathfinder_redis_host: str = "host.docker.internal"
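The new fields follow the file's existing `BaseSettings` pattern, so deployments should be able to override any of them through the `.env` file that compose mounts into each worker. A minimal sketch, assuming pydantic's default field-to-environment-variable mapping with no `env_prefix` (the override value below is invented):

```python
# Sketch only: assumes the project's Settings uses pydantic's default
# env-var mapping (field name <-> variable name, case-insensitive);
# check shepherd_utils/config.py for any prefix or alias configuration.
import os

os.environ["PLOVER_URL"] = "https://kg2cploverdb.test.transltr.io"  # hypothetical override

from shepherd_utils.config import settings

print(settings.plover_url)  # falls back to the config.py default when unset
```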
{}).get("edges", {}) + except: + qedges = {} + try: + # this can still fail if the input looks like e.g.: + # "query_graph": None + qpaths = message.get("message", {}).get("query_graph", {}).get("paths", {}) + except: + qpaths = {} + if len(qpaths) > 1: + raise Exception("Only a single path is supported", 400) + if (len(qpaths) > 0) and (len(qedges) > 0): + raise Exception("Mixed mode pathfinder queries are not supported", 400) + return len(qpaths) == 1 - response_id = task[1]["response_id"] - await save_message(response_id, result, logger) +async def arax(task, logger: logging.Logger): + start = time.time() + query_id = task[1]["query_id"] + logger.info(f"Getting message from db for query id {query_id}") + message = await get_message(query_id, logger) + if is_pathfinder_query(message): + task[1]["workflow"] = json.dumps([{"id": "arax.pathfinder"}]) + else: + try: + message["submitter"] = "Shepherd" + logger.info(f"Get the message from db {message}") + headers = {"Content-Type": "application/json"} + response = requests.post(settings.arax_url, json=message, headers=headers) + logger.info(f"Status Code from ARAX response: {response.status_code}") + result = response.json() + result = add_shepherd_arax_to_edge_sources(result) + except Exception as e: + logger.error(f"Error occurred calling ARAX service: {e}") + result = {"status": "error", "error": str(e)} + response_id = task[1]["response_id"] + await save_message(response_id, result, logger) + task[1]["workflow"] = json.dumps([{"id": "arax"}]) - task[1]["workflow"] = json.dumps([{"id": "arax"}]) + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def process_task(task, parent_ctx, logger: logging.Logger, limiter): diff --git a/workers/arax_pathfinder/Dockerfile b/workers/arax_pathfinder/Dockerfile new file mode 100644 index 0000000..890204f --- /dev/null +++ b/workers/arax_pathfinder/Dockerfile @@ -0,0 +1,34 @@ +# Use RENCI python base image +FROM ghcr.io/translatorsri/renci-python-image:3.11.5 + +# Add image info +LABEL org.opencontainers.image.source https://github.com/BioPack-team/shepherd + +ENV PYTHONHASHSEED=0 + +# set up requirements +WORKDIR /app + +# make sure all is writeable for the nru USER later on +RUN chmod -R 777 . + +# Install requirements +COPY ./shepherd_utils ./shepherd_utils +COPY ./pyproject.toml . +RUN pip install . + +COPY ./workers/arax_pathfinder/requirements.txt . +RUN pip install -r requirements.txt + +# switch to the non-root user (nru). 
diff --git a/workers/arax_pathfinder/Dockerfile b/workers/arax_pathfinder/Dockerfile
new file mode 100644
index 0000000..890204f
--- /dev/null
+++ b/workers/arax_pathfinder/Dockerfile
@@ -0,0 +1,34 @@
+# Use RENCI python base image
+FROM ghcr.io/translatorsri/renci-python-image:3.11.5
+
+# Add image info
+LABEL org.opencontainers.image.source https://github.com/BioPack-team/shepherd
+
+ENV PYTHONHASHSEED=0
+
+# set up requirements
+WORKDIR /app
+
+# make sure all is writable for the nru USER later on
+RUN chmod -R 777 .
+
+# Install requirements
+COPY ./shepherd_utils ./shepherd_utils
+COPY ./pyproject.toml .
+RUN pip install .
+
+COPY ./workers/arax_pathfinder/requirements.txt .
+RUN pip install -r requirements.txt
+
+# switch to the non-root user (nru) defined in the base image
+USER nru
+
+# Copy in files
+COPY ./workers/arax_pathfinder ./
+
+# Set up base for command and any variables
+# that shouldn't be modified
+# ENTRYPOINT ["uvicorn", "shepherd_server.server:APP"]
+
+# Variables that can be overridden
+CMD ["python", "worker.py"]
diff --git a/workers/arax_pathfinder/__init__.py b/workers/arax_pathfinder/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/workers/arax_pathfinder/requirements.txt b/workers/arax_pathfinder/requirements.txt
new file mode 100644
index 0000000..5f5123c
--- /dev/null
+++ b/workers/arax_pathfinder/requirements.txt
@@ -0,0 +1,2 @@
+catrax-pathfinder==1.2.2
+biolink-helper-pkg==1.0.0
\ No newline at end of file
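Before relying on the pinned packages inside the image, a quick smoke test can confirm they expose the module paths the worker imports. A sketch, assuming `catrax-pathfinder` and `biolink-helper-pkg` install the `pathfinder` and `biolink_helper_pkg` import names used by `worker.py`:

```python
# Hypothetical smoke test; run inside the built arax_pathfinder image.
# Import names are copied from worker.py's own imports below.
from biolink_helper_pkg import BiolinkHelper
from pathfinder.Pathfinder import Pathfinder

print(Pathfinder.__module__, BiolinkHelper.__module__)
```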
diff --git a/workers/arax_pathfinder/worker.py b/workers/arax_pathfinder/worker.py
new file mode 100644
index 0000000..2c1014d
--- /dev/null
+++ b/workers/arax_pathfinder/worker.py
@@ -0,0 +1,240 @@
+"""Arax ARA Pathfinder module."""
+
+import asyncio
+import json
+import logging
+import time
+import uuid
+from pathlib import Path
+
+import requests
+from biolink_helper_pkg import BiolinkHelper
+from pathfinder.Pathfinder import Pathfinder
+
+from shepherd_utils.config import settings
+from shepherd_utils.db import (
+    get_message,
+    save_message,
+)
+from shepherd_utils.inject_shepherd_arax_provenance import (
+    add_shepherd_arax_to_edge_sources,
+)
+from shepherd_utils.otel import setup_tracer
+from shepherd_utils.shared import (
+    get_tasks,
+    handle_task_failure,
+    wrap_up_task,
+)
+
+# Queue name
+STREAM = "arax.pathfinder"
+# Consumer group, most likely you don't need to change this.
+GROUP = "consumer"
+CONSUMER = str(uuid.uuid4())[:8]
+TASK_LIMIT = 100
+tracer = setup_tracer(STREAM)
+
+NUM_TOTAL_HOPS = 4
+MAX_HOPS_TO_EXPLORE = 4
+MAX_PATHFINDER_PATHS = 500
+PRUNE_TOP_K = 100
+NODE_DEGREE_THRESHOLD = 1000000
+
+OUT_PATH = Path("general_concepts.json")
+
+
+def download_file(url: str, out_path: Path, overwrite: bool = False) -> Path:
+    out_path = Path(out_path)
+
+    if out_path.exists() and not overwrite:
+        return out_path
+
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+
+    r = requests.get(url, timeout=60)
+    r.raise_for_status()
+
+    out_path.write_bytes(r.content)
+    return out_path
+
+
+def get_blocked_list():
+    download_file(settings.arax_blocked_list_url, OUT_PATH, False)
+
+    with open(OUT_PATH, "r") as file:
+        json_block_list = json.load(file)
+        synonyms = set(s.lower() for s in json_block_list["synonyms"])
+        return set(json_block_list["curies"]), synonyms
+
+
+def execute_pathfinding_sync(
+    pinned_node_ids, pinned_node_keys, intermediate_categories, logger
+):
+
+    blocked_curies, blocked_synonyms = get_blocked_list()
+
+    pathfinder_instance = Pathfinder(
+        "MLRepo",
+        settings.plover_url,
+        settings.curie_ngd_addr,
+        settings.node_degree_addr,
+        blocked_curies,
+        blocked_synonyms,
+        logger,
+    )
+
+    biolink_cache_dir = "/tmp/biolink"
+    Path(biolink_cache_dir).mkdir(parents=True, exist_ok=True)
+    biolink_helper = BiolinkHelper(settings.arax_biolink_version, biolink_cache_dir)
+    descendants = set(biolink_helper.get_descendants(intermediate_categories[0]))
+
+    start = time.perf_counter()
+    logger.info("Starting pathfinder.get_paths() in worker thread")
+
+    result, aux_graphs, knowledge_graph = pathfinder_instance.get_paths(
+        pinned_node_ids[0],
+        pinned_node_ids[1],
+        pinned_node_keys[0],
+        pinned_node_keys[1],
+        NUM_TOTAL_HOPS,
+        MAX_HOPS_TO_EXPLORE,
+        MAX_PATHFINDER_PATHS,
+        PRUNE_TOP_K,
+        NODE_DEGREE_THRESHOLD,
+        descendants,
+    )
+
+    elapsed = time.perf_counter() - start
+    logger.info(f"pathfinder.get_paths() finished in {elapsed:.3f} seconds")
+
+    return result, aux_graphs, knowledge_graph
+
+
+async def pathfinder(task, logger: logging.Logger):
+    start = time.time()
+    query_id = task[1]["query_id"]
+    response_id = task[1]["response_id"]
+    message = await get_message(query_id, logger)
+    parameters = message.get("parameters") or {}
+    parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
+    parameters["tiers"] = parameters.get("tiers") or [0]
+    message["parameters"] = parameters
+
+    qgraph = message["message"]["query_graph"]
+    pinned_node_keys = []
+    pinned_node_ids = []
+    for node_key, node in qgraph["nodes"].items():
+        pinned_node_keys.append(node_key)
+        if node.get("ids", None) is not None:
+            pinned_node_ids.append(node["ids"][0])
+    if len(set(pinned_node_ids)) != 2:
+        logger.error("Pathfinder queries require two pinned nodes.")
+        return message, 500
+
+    intermediate_categories = []
+    path_key = next(iter(qgraph["paths"].keys()))
+    qpath = qgraph["paths"][path_key]
+    if (
+        qpath.get("constraints", None) is not None
+        and len(qpath.get("constraints", [])) > 0
+    ):
+        constraints = qpath["constraints"]
+        if len(constraints) > 1:
+            logger.error("Pathfinder queries do not support multiple constraints.")
+            return message, 500
+        if len(constraints) > 0:
+            intermediate_categories = (
+                constraints[0].get("intermediate_categories", None) or []
+            )
+            if len(intermediate_categories) > 1:
+                logger.error(
+                    "Pathfinder queries do not support multiple intermediate categories"
+                )
+                return message, 500
+            # fall back to the most general category if the constraint omits one
+            if len(intermediate_categories) == 0:
+                intermediate_categories = ["biolink:NamedThing"]
+    else:
+        intermediate_categories = ["biolink:NamedThing"]
+
+    try:
+        result, aux_graphs, knowledge_graph = await asyncio.to_thread(
+            execute_pathfinding_sync,
+            pinned_node_ids,
+            pinned_node_keys,
+            intermediate_categories,
+            logger,
+        )
+
+        res = []
+        if result is not None:
+            res.append(
+                {
+                    "id": result["id"],
+                    "analyses": result["analyses"],
+                    "node_bindings": result["node_bindings"],
+                    "essence": "result",
+                }
+            )
+        if aux_graphs is None:
+            aux_graphs = {}
+        if knowledge_graph is None:
+            knowledge_graph = {}
+        message["message"]["knowledge_graph"] = knowledge_graph
+        message["message"]["auxiliary_graphs"] = aux_graphs
+        message["message"]["results"] = res
+
+        message = add_shepherd_arax_to_edge_sources(message)
+
+        await save_message(response_id, message, logger)
+    except Exception as e:
+        logger.error(
+            f"PathFinder failed to find paths between {pinned_node_keys[0]} and {pinned_node_keys[1]}. "
+            f"Error message is: {e}"
+        )
+        message = {"status": "error", "error": str(e)}
+        await save_message(response_id, message, logger)
+
+    logger.info(f"Task took {time.time() - start} seconds")
" + f"Error message is: {e}" + ) + message = {"status": "error", "error": str(e)} + await save_message(response_id, message, logger) + + logger.info(f"Task took {time.time() - start}") + + +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() + span = tracer.start_span(STREAM, context=parent_ctx) + try: + await pathfinder(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) + finally: + span.end() + limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") + + +async def poll_for_tasks(): + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying + + +if __name__ == "__main__": + asyncio.run(poll_for_tasks())