From d832d9fb02f57fca8a479dd13cb5cf0f70cbfcf6 Mon Sep 17 00:00:00 2001
From: Max Wang
Date: Wed, 4 Mar 2026 11:35:09 -0500
Subject: [PATCH 1/4] Add more robust error handling to all workers

---
 shepherd_utils/db.py                    |  17 ++-
 shepherd_utils/shared.py                |  27 ++++-
 workers/aragorn/worker.py               |  58 ++++++----
 workers/aragorn_lookup/worker.py        |  59 +++++++---
 workers/aragorn_omnicorp/worker.py      |  46 +++++---
 workers/aragorn_pathfinder/worker.py    |  94 ++++++++-------
 workers/aragorn_score/worker.py         |  73 +++++++-----
 workers/arax/worker.py                  |  41 ++++----
 workers/arax_rank/worker.py             |  85 ++++++++------
 workers/bte/worker.py                   |  56 +++++-----
 workers/bte_lookup/worker.py            |  64 +++++++----
 workers/example_ara/worker.py           |  66 ++++++-----
 workers/example_lookup/worker.py        |  81 +++++++------
 workers/example_score/worker.py         |  40 +++++--
 workers/filter_analyses_top_n/worker.py |  47 +++++---
 workers/filter_kgraph_orphans/worker.py |  37 ++++--
 workers/filter_results_top_n/worker.py  |  45 +++++---
 workers/finish_query/worker.py          |  59 +++++++---
 workers/gandalf/worker.py               | 146 +++++++++++++-----------
 workers/gandalf_rehydrate/worker.py     |  92 ++++++++-------
 workers/merge_message/worker.py         | 132 +++++++++++----------
 workers/sipr/worker.py                  |  46 +++++---
 workers/sort_results_score/worker.py    |  76 ++++++------
 23 files changed, 907 insertions(+), 580 deletions(-)

diff --git a/shepherd_utils/db.py b/shepherd_utils/db.py
index a4e1149..c0ed9be 100644
--- a/shepherd_utils/db.py
+++ b/shepherd_utils/db.py
@@ -184,16 +184,14 @@ async def get_message(
     """Get the message from db."""
     message = {}
     start = time.time()
-    try:
-        # print(f"Putting {query_id} on {ara_target} stream")
-        message = await data_db_client.get(message_id)
-        if message is not None:
-            start_decomp = time.time()
-            message = orjson.loads(zstandard.decompress(message))
-            logger.debug(f"Decompression took {time.time() - start_decomp}")
-    except Exception as e:
+    message = await data_db_client.get(message_id)
+    if message is None:
         # failed to get message from db
-        logger.error(f"Failed to get {message_id} from db: {e}")
+        raise Exception(f"Failed to get {message_id} from db")
+
+    start_decomp = time.time()
+    message = orjson.loads(zstandard.decompress(message))
+    logger.debug(f"Decompression took {time.time() - start_decomp}")
     logger.debug(f"Getting message took {time.time() - start} seconds")
     return message
 
@@ -341,6 +339,7 @@ async def get_running_callbacks(
                 continue
         except Exception as e:
             logger.error(f"Failed to get running lookups: {e}")
+            raise
 
     return running_lookups
 
diff --git a/shepherd_utils/shared.py b/shepherd_utils/shared.py
index 1c11b90..65e6ee4 100644
--- a/shepherd_utils/shared.py
+++ b/shepherd_utils/shared.py
@@ -75,10 +75,11 @@ async def wrap_up_task(
     stream: str,
     group: str,
     task: Tuple[str, dict],
-    workflow: List[dict],
     logger: logging.Logger,
 ):
     """Call the next task and mark this one as complete."""
+    workflow = json.loads(task[1]["workflow"])
+    logger.debug(workflow)
     # remove the operation we just did
     if stream == workflow[0]["id"]:
         # make sure the worker is in the workflow
@@ -106,6 +107,30 @@
     await save_logs(task[1]["response_id"], logger)
 
 
+async def handle_task_failure(
+    stream: str,
+    group: str,
+    task: Tuple[str, dict],
+    logger: logging.Logger,
+) -> None:
+    """Handle any full query failures."""
+    await mark_task_as_complete(stream, group, task[0], logger)
+    await save_logs(task[1]["response_id"], logger)
+    logger.error("Sending task straight to finish_query.")
+    await add_task(
+        "finish_query",
+        {
+            "query_id": task[1]["query_id"],
+            "response_id": task[1]["response_id"],
+            "workflow": "[]",
+            "log_level": task[1].get("log_level", 20),
+            "otel": task[1]["otel"],
+            "status": "ERROR",
+        },
+        logger,
+    )
+
+
 def recursive_get_edge_support_graphs(
     edge: str,
     edges: set,
diff --git a/workers/aragorn/worker.py b/workers/aragorn/worker.py
index 3b60502..c1c317e 100644
--- a/workers/aragorn/worker.py
+++ b/workers/aragorn/worker.py
@@ -8,7 +8,7 @@
 
 from shepherd_utils.db import get_message
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "aragorn"
@@ -31,27 +31,27 @@ def examine_query(message):
         # this can still fail if the input looks like e.g.:
         # "query_graph": None
         qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
-    except:
+    except AttributeError:
         qedges = {}
     try:
         # this can still fail if the input looks like e.g.:
         # "query_graph": None
         qpaths = message.get("message", {}).get("query_graph", {}).get("paths", {})
-    except:
+    except AttributeError:
         qpaths = {}
     if len(qpaths) > 1:
-        raise Exception("Only a single path is supported", 400)
+        raise Exception("Only a single path is supported")
     if (len(qpaths) > 0) and (len(qedges) > 0):
-        raise Exception("Mixed mode pathfinder queries are not supported", 400)
+        raise Exception("Mixed mode pathfinder queries are not supported")
     pathfinder = len(qpaths) == 1
     n_infer_edges = 0
     for edge_id in qedges:
         if qedges.get(edge_id, {}).get("knowledge_type", "lookup") == "inferred":
             n_infer_edges += 1
     if n_infer_edges > 1 and n_infer_edges:
-        raise Exception("Only a single infer edge is supported", 400)
+        raise Exception("Only a single infer edge is supported")
     if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
-        raise Exception("Mixed infer and lookup queries not supported", 400)
+        raise Exception("Mixed infer and lookup queries not supported")
     infer = n_infer_edges == 1
     if not infer:
         return infer, None, None, pathfinder
@@ -64,23 +64,19 @@ def examine_query(message):
         else:
             question_node = qnode_id
     if answer_node is None:
-        raise Exception("Both nodes of creative edge pinned", 400)
+        raise Exception("Both nodes of creative edge pinned")
     if question_node is None:
-        raise Exception("No nodes of creative edge pinned", 400)
+        raise Exception("No nodes of creative edge pinned")
     return infer, question_node, answer_node, pathfinder
 
 
 async def aragorn(task, logger: logging.Logger):
-    start = time.time()
+    """Examine and define the Aragorn workflow."""
     # given a task, get the message from the db
     query_id = task[1]["query_id"]
     workflow = json.loads(task[1]["workflow"])
     message = await get_message(query_id, logger)
-    try:
-        infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
-    except Exception as e:
-        logger.error(e)
-        return None, 500
+    infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
 
     if workflow is None:
         if infer:
@@ -113,24 +109,44 @@ async def aragorn(task, logger: logging.Logger):
             {"id": "filter_kgraph_orphans"},
         ]
 
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-    logger.info(f"Task took {time.time() - start}")
+    task[1]["workflow"] = json.dumps(workflow)
 
 
 async def process_task(task, parent_ctx, logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await aragorn(task, logger)
+        try:
+            logger.debug(task)
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled.")
+    except Exception as e:
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+        logger.info(f"Task took {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
diff --git a/workers/aragorn_lookup/worker.py b/workers/aragorn_lookup/worker.py
index 56e7c38..3ac7374 100644
--- a/workers/aragorn_lookup/worker.py
+++ b/workers/aragorn_lookup/worker.py
@@ -23,7 +23,7 @@
     save_message,
 )
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import add_task, get_tasks, wrap_up_task
+from shepherd_utils.shared import add_task, get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "aragorn.lookup"
@@ -46,7 +46,7 @@ def examine_query(message):
         # this can still fail if the input looks like e.g.:
         # "query_graph": None
         qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
-    except:
+    except AttributeError:
         qedges = {}
     n_infer_edges = 0
     for edge_id in qedges:
@@ -54,9 +54,9 @@ def examine_query(message):
             n_infer_edges += 1
     pathfinder = n_infer_edges == 3
     if n_infer_edges > 1 and n_infer_edges and not pathfinder:
-        raise Exception("Only a single infer edge is supported", 400)
+        raise Exception("Only a single infer edge is supported")
     if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
-        raise Exception("Mixed infer and lookup queries not supported", 400)
+        raise Exception("Mixed infer and lookup queries not supported")
     infer = n_infer_edges == 1
     if not infer:
         return infer, None, None, pathfinder
@@ -69,9 +69,9 @@ def examine_query(message):
         else:
             question_node = qnode_id
     if answer_node is None:
-        raise Exception("Both nodes of creative edge pinned", 400)
+        raise Exception("Both nodes of creative edge pinned")
     if question_node is None:
-        raise Exception("No nodes of creative edge pinned", 400)
+        raise Exception("No nodes of creative edge pinned")
     return infer, question_node, answer_node, pathfinder
 
 
@@ -109,11 +109,10 @@ async def run_async_lookup(
 
 
 async def aragorn_lookup(task, logger: logging.Logger):
-    start = time.time()
+    """Do Aragorn lookup operation."""
    # given a task, get the message from the db
     query_id = task[1]["query_id"]
     response_id = task[1]["response_id"]
-    workflow = json.loads(task[1]["workflow"])
     message = await get_message(query_id, logger)
     parameters = message.get("parameters") or {}
     parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
@@ -241,9 +240,13 @@
     start_time = time.time()
     running_callback_ids = [""]
     while time.time() - start_time < MAX_QUERY_TIME:
-        # see if there are existing lookups going
-        
running_callback_ids = await get_running_callbacks(query_id, logger) - # logger.info(f"Got back {len(running_callback_ids)} running lookups") + try: + # see if there are existing lookups going + running_callback_ids = await get_running_callbacks(query_id, logger) + except Exception: + # Brief backoff then retry the check rather than giving up + await asyncio.sleep(5) + continue # if there are, continue to wait if len(running_callback_ids) > 0: await asyncio.sleep(1) @@ -260,9 +263,6 @@ async def aragorn_lookup(task, logger: logging.Logger): # logger.warning(f"Running callbacks: {running_callback_ids}") await cleanup_callbacks(query_id, logger) - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Finished task {task[0]} in {time.time() - start}") - def get_infer_parameters(input_message): """Given an infer input message, return the parameters needed to run the infer. @@ -386,20 +386,41 @@ def expand_aragorn_query(input_message, logger: logging.Logger): return messages -async def process_task(task, parent_ctx, logger, limiter): +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await aragorn_lookup(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/aragorn_omnicorp/worker.py b/workers/aragorn_omnicorp/worker.py index b2ae04e..9a46036 100644 --- a/workers/aragorn_omnicorp/worker.py +++ b/workers/aragorn_omnicorp/worker.py @@ -7,9 +7,9 @@ import time import uuid from shepherd_utils.config import settings -from shepherd_utils.db import get_message +from shepherd_utils.db import get_message, save_message from shepherd_utils.otel import setup_tracer -from shepherd_utils.shared import get_tasks, wrap_up_task +from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task # Queue name STREAM = "aragorn.omnicorp" @@ -21,36 +21,56 @@ async def aragorn_omnicorp(task, logger: logging.Logger): - start = time.time() # given a task, get the message from the db response_id = task[1]["response_id"] - workflow = json.loads(task[1]["workflow"]) message = await get_message(response_id, logger) - async with httpx.AsyncClient(timeout=100) as client: - await client.post( + async with 
httpx.AsyncClient(timeout=120) as client: + response = await client.post( settings.omnicorp_url, json=message, ) + response.raise_for_status() - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Task took {time.time() - start}") + response = response.json() + await save_message(response_id, response, logger) -async def process_task(task, parent_ctx, logger, limiter): +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await aragorn_omnicorp(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/aragorn_pathfinder/worker.py b/workers/aragorn_pathfinder/worker.py index 9afe88c..2c2be4a 100644 --- a/workers/aragorn_pathfinder/worker.py +++ b/workers/aragorn_pathfinder/worker.py @@ -18,6 +18,7 @@ from shepherd_utils.shared import ( add_task, get_tasks, + handle_task_failure, wrap_up_task, ) @@ -35,10 +36,8 @@ async def shadowfax(task, logger: logging.Logger): co-occurrence to find nodes that occur in publications with our input nodes, then finding paths that connect our input nodes through these intermediate nodes.""" - start = time.time() # given a task, get the message from the db query_id = task[1]["query_id"] - workflow = json.loads(task[1]["workflow"]) response_id = task[1]["response_id"] message = await get_message(query_id, logger) parameters = message.get("parameters") or {} @@ -55,9 +54,7 @@ async def shadowfax(task, logger: logging.Logger): # TODO: silently only grabbing the first id pinned_node_ids.append(node["ids"][0]) if len(set(pinned_node_ids)) != 2: - logger.error("Pathfinder queries require two pinned nodes.") - # TODO: Update to wrap up task - return message, 500 + raise Exception("Pathfinder queries require two pinned nodes.") intermediate_categories = [] path_key = next(iter(qgraph["paths"].keys())) @@ -66,17 +63,15 @@ async def shadowfax(task, logger: logging.Logger): constraints = qpath["constraints"] # TODO: need to wrap up tasks if len(constraints) > 1: - logger.error("Pathfinder queries do not support multiple constraints.") - return message, 500 + raise Exception("Pathfinder queries do not support multiple constraints.") if len(constraints) > 0: intermediate_categories 
= ( constraints[0].get("intermediate_categories", None) or [] ) if len(intermediate_categories) > 1: - logger.error( + raise Exception( "Pathfinder queries do not support multiple intermediate categories" ) - return message, 500 else: intermediate_categories = ["biolink:NamedThing"] @@ -184,9 +179,9 @@ async def shadowfax(task, logger: logging.Logger): callback_id = str(uuid.uuid4())[:8] # Put callback UID and query ID in postgres await add_callback_id(query_id, callback_id, logger) - logger.debug("""Sending pathfinder lookup query to gandalf.""") await save_message(callback_id, threehop, logger) + logger.debug("""Sending pathfinder lookup query to gandalf.""") await add_task( "gandalf", @@ -206,44 +201,65 @@ async def shadowfax(task, logger: logging.Logger): MAX_QUERY_TIME = message["parameters"]["timeout"] start_time = time.time() running_callback_ids = [""] - while time.time() - start_time < MAX_QUERY_TIME: - # see if there are existing lookups going - running_callback_ids = await get_running_callbacks(query_id, logger) - # logger.info(f"Got back {len(running_callback_ids)} running lookups") - # if there are, continue to wait - if len(running_callback_ids) > 0: - await asyncio.sleep(1) - continue - # if there aren't, lookup is complete and we need to pass on to next workflow operation - if len(running_callback_ids) == 0: - logger.debug("Got all lookups back. Continuing...") - break - - if time.time() - start_time > MAX_QUERY_TIME: - logger.warning( - f"Timed out getting lookup callbacks. {len(running_callback_ids)} queries were still running..." - ) - logger.warning(f"Running callbacks: {running_callback_ids}") - await cleanup_callbacks(query_id, logger) - - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Task took {time.time() - start}") - - -async def process_task(task, parent_ctx, logger, limiter): + try: + while time.time() - start_time < MAX_QUERY_TIME: + # see if there are existing lookups going + running_callback_ids = await get_running_callbacks(query_id, logger) + # logger.info(f"Got back {len(running_callback_ids)} running lookups") + # if there are, continue to wait + if len(running_callback_ids) > 0: + await asyncio.sleep(1) + continue + # if there aren't, lookup is complete and we need to pass on to next workflow operation + if len(running_callback_ids) == 0: + logger.debug("Got all lookups back. Continuing...") + break + + if time.time() - start_time > MAX_QUERY_TIME: + logger.warning( + f"Timed out getting lookup callbacks. {len(running_callback_ids)} queries were still running..." 
+ ) + logger.warning(f"Running callbacks: {running_callback_ids}") + await cleanup_callbacks(query_id, logger) + except asyncio.CancelledError: + logger.warning(f"Task {task[0]}: Cancelled while waiting for callbacks") + + +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await shadowfax(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/aragorn_score/worker.py b/workers/aragorn_score/worker.py index dbe492a..b12ebc3 100644 --- a/workers/aragorn_score/worker.py +++ b/workers/aragorn_score/worker.py @@ -16,7 +16,7 @@ from shepherd_utils.db import get_message, save_message from shepherd_utils.otel import setup_tracer -from shepherd_utils.shared import get_tasks, wrap_up_task +from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task # Queue name STREAM = "aragorn.score" @@ -1218,38 +1218,55 @@ def aragorn_score(in_message, logger: logging.Logger): async def poll_for_tasks(): + """On initialization, poll indefinitely for available tasks.""" loop = asyncio.get_running_loop() cpu_count = os.cpu_count() cpu_count = cpu_count if cpu_count is not None else 1 cpu_count = min(cpu_count, TASK_LIMIT) executor = ProcessPoolExecutor(max_workers=cpu_count) - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - span = tracer.start_span(STREAM, context=parent_ctx) - start = time.time() - # given a task, get the message from the db - response_id = task[1]["response_id"] - workflow = json.loads(task[1]["workflow"]) - message = await get_message(response_id, logger) - if message is not None: - scored_message = await loop.run_in_executor( - executor, - aragorn_score, - message, - logger, - ) - if scored_message is None: - logger.error("Failed to score message. 
Returning unscored.") - scored_message = message - await save_message(response_id, scored_message, logger) - else: - logger.error(f"Failed to get {response_id} for scoring.") - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - - logger.info(f"Finished task {task[0]} in {time.time() - start}") - span.end() - limiter.release() + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + start = time.time() + span = tracer.start_span(STREAM, context=parent_ctx) + try: + # given a task, get the message from the db + response_id = task[1]["response_id"] + message = await get_message(response_id, logger) + if message is not None: + scored_message = await loop.run_in_executor( + executor, + aragorn_score, + message, + logger, + ) + if scored_message is None: + logger.error("Failed to score message. Returning unscored.") + scored_message = message + await save_message(response_id, scored_message, logger) + else: + logger.error(f"Failed to get {response_id} for scoring.") + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) + finally: + logger.info(f"Finished task {task[0]} in {time.time() - start}") + span.end() + limiter.release() + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/arax/worker.py b/workers/arax/worker.py index cb34435..53909db 100644 --- a/workers/arax/worker.py +++ b/workers/arax/worker.py @@ -1,13 +1,14 @@ """ARAX entry module.""" import asyncio +import json import logging import requests import time import uuid from shepherd_utils.config import settings from shepherd_utils.db import get_message, save_message -from shepherd_utils.shared import get_tasks, wrap_up_task +from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task from shepherd_utils.otel import setup_tracer from inject_shepherd_arax_provenance import add_shepherd_arax_to_edge_sources @@ -21,7 +22,6 @@ async def arax(task, logger: logging.Logger): try: - start = time.time() query_id = task[1]["query_id"] logger.info(f"Getting message from db for query id {query_id}") message = await get_message(query_id, logger) @@ -43,29 +43,44 @@ async def arax(task, logger: logging.Logger): await save_message(response_id, result, logger) - workflow = [{"id": "arax"}] + task[1]["workflow"] = json.dumps([{"id": "arax"}]) - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Finished task {task[0]} in {time.time() - start}") - - -async def process_task(task, parent_ctx, logger, limiter): +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await arax(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + 
except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") except Exception as e: - logger.error(f"Something went wrong: {e}") + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/arax_rank/worker.py b/workers/arax_rank/worker.py index 43d9f8f..109bda1 100644 --- a/workers/arax_rank/worker.py +++ b/workers/arax_rank/worker.py @@ -20,7 +20,7 @@ from shepherd_utils.db import get_message, save_message from shepherd_utils.otel import setup_tracer -from shepherd_utils.shared import get_tasks, wrap_up_task +from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task from ranker import arax_rank @@ -92,40 +92,55 @@ async def poll_for_tasks() -> None: cpu_count = min(cpu_count, TASK_LIMIT) executor = ProcessPoolExecutor(max_workers=cpu_count) - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - span = tracer.start_span(STREAM, context=parent_ctx) - start = time.time() - - # Get task details - response_id = task[1]["response_id"] - workflow = json.loads(task[1]["workflow"]) - - # Get message from Redis - message = await get_message(response_id, logger) - - if message is not None: - # Run ranking in process pool for CPU-intensive operations - ranked_message = await loop.run_in_executor( - executor, - rank_message, - message, - logger, - ) - if ranked_message is None: - logger.error("Ranking returned None. Returning original message.") - ranked_message = message - await save_message(response_id, ranked_message, logger) - else: - logger.error(f"Failed to get {response_id} for ranking.") - - # Pass to next operation in workflow - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - - logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s") - span.end() - limiter.release() + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + start = time.time() + span = tracer.start_span(STREAM, context=parent_ctx) + + try: + # Get task details + response_id = task[1]["response_id"] + + # Get message from Redis + message = await get_message(response_id, logger) + + if message is not None: + # Run ranking in process pool for CPU-intensive operations + ranked_message = await loop.run_in_executor( + executor, + rank_message, + message, + logger, + ) + if ranked_message is None: + logger.error("Ranking returned None. 
Returning original message.")
+                        ranked_message = message
+                    await save_message(response_id, ranked_message, logger)
+                else:
+                    logger.error(f"Failed to get {response_id} for ranking.")
+
+                # Always wrap up the task to ACK it in the broker
+                try:
+                    await wrap_up_task(STREAM, GROUP, task, logger)
+                except Exception as e:
+                    logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+            except asyncio.CancelledError:
+                logger.warning(f"Task {task[0]} was cancelled")
+            except Exception as e:
+                logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+                await handle_task_failure(STREAM, GROUP, task, logger)
+            finally:
+                logger.info(f"Finished task {task[0]} in {time.time() - start}")
+                span.end()
+                limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
diff --git a/workers/bte/worker.py b/workers/bte/worker.py
index b5b2220..5d8e890 100644
--- a/workers/bte/worker.py
+++ b/workers/bte/worker.py
@@ -8,7 +8,7 @@
 
 from shepherd_utils.db import get_message
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "bte"
@@ -31,7 +31,7 @@ def examine_query(message):
         # this can still fail if the input looks like e.g.:
         # "query_graph": None
         qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
-    except:
+    except AttributeError:
         qedges = {}
     n_infer_edges = 0
     for edge_id in qedges:
@@ -39,9 +39,9 @@ def examine_query(message):
             n_infer_edges += 1
     pathfinder = n_infer_edges == 3
     if n_infer_edges > 1 and n_infer_edges and not pathfinder:
-        raise Exception("Only a single infer edge is supported", 400)
+        raise Exception("Only a single infer edge is supported")
     if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
-        raise Exception("Mixed infer and lookup queries not supported", 400)
+        raise Exception("Mixed infer and lookup queries not supported")
     infer = n_infer_edges == 1
     if not infer:
         return infer, None, None, pathfinder
@@ -54,27 +54,21 @@ def examine_query(message):
         else:
             question_node = qnode_id
     if answer_node is None:
-        raise Exception("Both nodes of creative edge pinned", 400)
+        raise Exception("Both nodes of creative edge pinned")
     if question_node is None:
-        raise Exception("No nodes of creative edge pinned", 400)
+        raise Exception("No nodes of creative edge pinned")
     return infer, question_node, answer_node, pathfinder
 
 
 async def bte(task, logger: logging.Logger):
     """Main BTE entrypoint that establishes query workflow."""
-    start = time.time()
     # given a task, get the message from the db
     query_id = task[1]["query_id"]
     workflow = json.loads(task[1]["workflow"])
     message = await get_message(query_id, logger)
-    try:
-        infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
-    except Exception as e:
-        logger.error(e)
-        return None, 500
+    infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
     if pathfinder:
-        # BTE doesn't currently handle Pathfinder queries
-        return None, 500
+        raise Exception("BTE does not support Pathfinder type queries.")
 
     if workflow is None:
         if infer:
@@ -96,24 +90,44 @@ async def bte(task, logger: logging.Logger):
         {"id": "filter_kgraph_orphans"},
     ]
 
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-    logger.info(f"Task took {time.time() - start}")
+    task[1]["workflow"] = json.dumps(workflow)
 
 
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await bte(task, logger)
+        # Always wrap up the task to ACK it in the broker
+        try:
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled")
+    except Exception as e:
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+        logger.info(f"Finished task {task[0]} in {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
diff --git a/workers/bte_lookup/worker.py b/workers/bte_lookup/worker.py
index 7cceab6..5ca5397 100644
--- a/workers/bte_lookup/worker.py
+++ b/workers/bte_lookup/worker.py
@@ -23,7 +23,7 @@
     save_message,
 )
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "bte.lookup"
@@ -46,7 +46,7 @@ def examine_query(message):
         # this can still fail if the input looks like e.g.:
         # "query_graph": None
         qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
-    except:
+    except AttributeError:
         qedges = {}
     n_infer_edges = 0
     for edge_id in qedges:
@@ -54,9 +54,9 @@ def examine_query(message):
             n_infer_edges += 1
     pathfinder = n_infer_edges == 3
     if n_infer_edges > 1 and n_infer_edges and not pathfinder:
-        raise Exception("Only a single infer edge is supported", 400)
+        raise Exception("Only a single infer edge is supported")
     if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
-        raise Exception("Mixed infer and lookup queries not supported", 400)
+        raise Exception("Mixed infer and lookup queries not supported")
     infer = n_infer_edges == 1
     if not infer:
         return infer, None, None, pathfinder
@@ -69,9 +69,9 @@ def examine_query(message):
         else:
             question_node = qnode_id
     if answer_node is None:
-        raise Exception("Both nodes of creative edge pinned", 400)
+        raise Exception("Both nodes of creative edge pinned")
     if question_node is None:
-        raise Exception("No nodes of creative edge pinned", 400)
+        raise Exception("No nodes of creative edge pinned")
     return infer, question_node, answer_node, pathfinder
 
 
@@ -109,10 +109,8 @@ async def run_async_lookup(
 
 
 async def bte_lookup(task, logger: logging.Logger):
-    start = time.time()
     # given a task, get the message from the db
     query_id = task[1]["query_id"]
-    workflow = json.loads(task[1]["workflow"])
     message = await 
get_message(query_id, logger) parameters = message.get("parameters") or {} parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout) @@ -126,14 +124,10 @@ async def bte_lookup(task, logger: logging.Logger): url=settings.server_url, ) ) - try: - infer, question_qnode, answer_qnode, pathfinder = examine_query(message) - except Exception as e: - logger.error(e) - return None, 500 + infer, question_qnode, answer_qnode, pathfinder = examine_query(message) if pathfinder: # BTE currently doesn't handle Pathfinder queries - return None, 500 + raise Exception("BTE does not support Pathfinder type queries.") if not infer: # Put callback UID and query ID in postgres @@ -192,9 +186,13 @@ async def bte_lookup(task, logger: logging.Logger): start_time = time.time() running_callback_ids = [""] while time.time() - start_time < MAX_QUERY_TIME: - # see if there are existing lookups going - running_callback_ids = await get_running_callbacks(query_id, logger) - # logger.info(f"Got back {len(running_callback_ids)} running lookups") + try: + # see if there are existing lookups going + running_callback_ids = await get_running_callbacks(query_id, logger) + except Exception: + # Brief backoff then retry the check rather than giving up + await asyncio.sleep(5) + continue # if there are, continue to wait if len(running_callback_ids) > 0: await asyncio.sleep(1) @@ -210,9 +208,6 @@ async def bte_lookup(task, logger: logging.Logger): ) await cleanup_callbacks(query_id, logger) - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Finished task {task[0]} in {time.time() - start}") - class TemplateGroup(BaseModel): """A group of templates to be matched by given criteria.""" @@ -422,19 +417,40 @@ def expand_bte_query(query_dict: dict[str, Any], logger: logging.Logger) -> list async def process_task(task, parent_ctx, logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await bte_lookup(task, logger) + try: + # Always wrap up the task to ACK it in the broker + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled.") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Task took {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/example_ara/worker.py b/workers/example_ara/worker.py index b49be17..03afec0 100644 --- a/workers/example_ara/worker.py +++ b/workers/example_ara/worker.py @@ -6,7 +6,7 @@ import time import uuid from shepherd_utils.db import get_message -from 
shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 from shepherd_utils.otel import setup_tracer
 
 # Queue name
@@ -18,46 +18,58 @@
 
 
 async def example_ara(task, logger: logging.Logger):
-    try:
-        start = time.time()
-        # given a task, get the message from the db
-        logger.info("Getting message from db")
-        message = await get_message(task[1]["query_id"], logger)
-        # logger.info(message)
-        logger.info(task)
-
-        workflow = [
-            {"id": "example.lookup"},
-            {"id": "example.score"},
-            {"id": "sort_results_score"},
-            {"id": "filter_results_top_n"},
-            {"id": "filter_kgraph_orphans"},
-        ]
-    except Exception as e:
-        logger.error(f"Something bad happened! {e}")
-        # TODO: gracefully handle worker errors
+    # given a task, get the message from the db
+    logger.info("Getting message from db")
+    message = await get_message(task[1]["query_id"], logger)
+    # logger.info(message)
+    logger.info(task)
 
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
+    workflow = [
+        {"id": "example.lookup"},
+        {"id": "example.score"},
+        {"id": "sort_results_score"},
+        {"id": "filter_results_top_n"},
+        {"id": "filter_kgraph_orphans"},
+    ]
 
-    logger.info(f"Finished task {task[0]} in {time.time() - start}")
+    task[1]["workflow"] = json.dumps(workflow)
 
 
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await example_ara(task, logger)
+        # Always wrap up the task to ACK it in the broker
+        try:
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled")
     except Exception as e:
-        logger.error(f"Something went wrong: {e}")
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+        logger.info(f"Finished task {task[0]} in {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
diff --git a/workers/example_lookup/worker.py b/workers/example_lookup/worker.py
index 735c2e5..a0832e7 100644
--- a/workers/example_lookup/worker.py
+++ b/workers/example_lookup/worker.py
@@ -9,14 +9,16 @@
 
 import httpx
 
+from shepherd_utils.config import settings
 from shepherd_utils.db import (
     add_callback_id,
+    cleanup_callbacks,
     get_message,
     get_running_callbacks,
     save_message,
 )
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "example.lookup"
@@ -31,15 +33,11 @@ async def example_lookup(task, logger: logging.Logger):
 
     Just sends a test response back to the server callback endpoint.
     """
-    start = time.time()
     # given a task, get the message from the db
     query_id = task[1]["query_id"]
-    workflow = json.loads(task[1]["workflow"])
-    try:
-        message = await get_message(query_id, logger)
-    except Exception as e:
-        logger.error(f"Task {task[0]}: Failed to get message for query {query_id}: {e}")
-        raise
+    message = await get_message(query_id, logger)
+    parameters = message.get("parameters") or {}
+    parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
 
     # Do query expansion or whatever lookup process
     # We're going to stub a response
@@ -93,57 +91,56 @@ async def example_lookup(task, logger: logging.Logger):
 
     # this worker might have a timeout set for if the lookups don't finish within a
     # certain amount of time
-    MAX_QUERY_TIME = 300
+    MAX_QUERY_TIME = parameters["timeout"]
     start_time = time.time()
-    try:
-        while time.time() - start_time < MAX_QUERY_TIME:
-            try:
-                # see if there are existing lookups going
-                running_callback_ids = await get_running_callbacks(query_id, logger)
-            except Exception as e:
-                logger.error(f"Task {task[0]}: Failed to check running callbacks: {e}")
-                # Brief backoff then retry the check rather than giving up
-                await asyncio.sleep(5)
-                continue
-            # logger.info(f"Got back {len(running_callback_ids)} running lookups")
-            # if there aren't, lookup is complete and we need to pass on to next
-            # workflow operation
-            if len(running_callback_ids) == 0:
-                break
-
-            await asyncio.sleep(1)
-    except asyncio.CancelledError:
-        logger.warning(f"Task {task[0]}: Cancelled while waiting for callbacks.")
-
-    # Always wrap up the task to ACK it in the broker
-    try:
-        await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-    except Exception as e:
-        logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
-        raise
-    logger.info(f"Finished task {task[0]} in {time.time() - start}")
+    running_callback_ids = [""]
+    while time.time() - start_time < MAX_QUERY_TIME:
+        try:
+            # see if there are existing lookups going
+            running_callback_ids = await get_running_callbacks(query_id, logger)
+        except Exception:
+            # Brief backoff then retry the check rather than giving up
+            await asyncio.sleep(5)
+            continue
+        # if there aren't, lookup is complete and we need to pass on to next
+        # workflow operation
+        if len(running_callback_ids) == 0:
+            break
+
+        await asyncio.sleep(1)
+
+    if time.time() - start_time > MAX_QUERY_TIME:
+        logger.warning(
+            f"Timed out getting lookup callbacks. {len(running_callback_ids)} queries were still running..." 
+ ) + # logger.warning(f"Running callbacks: {running_callback_ids}") + await cleanup_callbacks(query_id, logger) async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await example_lookup(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") except asyncio.CancelledError: logger.warning(f"Task {task[0]} was cancelled") except Exception as e: logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) - # Attempt to ACK the task so it doesn't get redelivered forever - try: - workflow = json.loads(task[1])["workflow"] - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - except Exception as wrap_err: - logger.error(f"Task {task[0]}: Also failed wrap up after error: {wrap_err}") + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): + """On initialization, poll indefinitely for available tasks.""" while True: try: async for task, parent_ctx, logger, limiter in get_tasks( diff --git a/workers/example_score/worker.py b/workers/example_score/worker.py index ee9f7bb..08ba9fb 100644 --- a/workers/example_score/worker.py +++ b/workers/example_score/worker.py @@ -7,7 +7,7 @@ import time import uuid from shepherd_utils.db import get_message, save_message -from shepherd_utils.shared import get_tasks, wrap_up_task +from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task from shepherd_utils.otel import setup_tracer # Queue name @@ -19,36 +19,54 @@ async def example_score(task, logger: logging.Logger): - start = time.time() + """Do a very random score.""" # given a task, get the message from the db response_id = task[1]["response_id"] - workflow = json.loads(task[1]["workflow"]) message = await get_message(response_id, logger) # give a random score to all results - for result in message["message"]["results"]: + for result in message["message"].get("results", []): for analysis in result["analyses"]: analysis["score"] = random.random() await save_message(response_id, message, logger) - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Finished task {task[0]} in {time.time() - start}") -async def process_task(task, parent_ctx, logger, limiter): +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await example_score(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, 
logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
diff --git a/workers/filter_analyses_top_n/worker.py b/workers/filter_analyses_top_n/worker.py
index 92e0ed7..9015f40 100644
--- a/workers/filter_analyses_top_n/worker.py
+++ b/workers/filter_analyses_top_n/worker.py
@@ -6,7 +6,7 @@
 import time
 import uuid
 from shepherd_utils.db import get_message, save_message, get_query_state
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 from shepherd_utils.otel import setup_tracer
 
 # Queue name
@@ -18,7 +18,7 @@
 
 
 async def filter_analyses_top_n(task, logger: logging.Logger):
-    start = time.time()
+    """Filter results analyses to top n."""
     # given a task, get the message from the db
     response_id = task[1]["response_id"]
     workflow = json.loads(task[1]["workflow"])
@@ -29,36 +29,49 @@ async def filter_analyses_top_n(task, logger: logging.Logger):
         logger.error(f"Unable to find operation {STREAM} in workflow")
         raise Exception(f"Operation {STREAM} is not in workflow")
     n = current_op.get("max_analyses", 1000)
-    try:
-        for ind, result in enumerate(message["message"]["results"]):
-            message["message"]["results"][ind]["analyses"] = result["analyses"][:n]
-    except KeyError as e:
-        # can't find the right structure of message
-        logger.error(f"Error filtering results: {e}")
-        return message, 400
+    for ind, result in enumerate(message["message"].get("results", [])):
+        message["message"]["results"][ind]["analyses"] = result["analyses"][:n]
 
     logger.info("Returning filtered results.")
     # save merged message back to db
     await save_message(response_id, message, logger)
 
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-    logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await filter_analyses_top_n(task, logger)
+        # Always wrap up the task to ACK it in the broker
+        try:
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled")
+    except Exception as e:
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+        logger.info(f"Finished task {task[0]} in {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        
except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/filter_kgraph_orphans/worker.py b/workers/filter_kgraph_orphans/worker.py index 3c1140b..ead4723 100644 --- a/workers/filter_kgraph_orphans/worker.py +++ b/workers/filter_kgraph_orphans/worker.py @@ -11,6 +11,7 @@ from shepherd_utils.shared import ( filter_kgraph_orphans, get_tasks, + handle_task_failure, wrap_up_task, ) @@ -27,34 +28,50 @@ async def do_filter_kgraph_orphans(task, logger: logging.Logger): Given a TRAPI message, remove all kgraph nodes and edges that aren't referenced in any results. """ - start = time.time() # given a task, get the message from the db response_id = task[1]["response_id"] - workflow = json.loads(task[1]["workflow"]) message = await get_message(response_id, logger) filter_kgraph_orphans(message, logger) # save merged message back to db await save_message(response_id, message, logger) - await wrap_up_task(STREAM, GROUP, task, workflow, logger) - logger.info(f"Finished task {task[0]} in {time.time() - start}") - -async def process_task(task, parent_ctx, logger, limiter): +async def process_task(task, parent_ctx, logger: logging.Logger, limiter): + """Process a given task and ACK in redis.""" + start = time.time() span = tracer.start_span(STREAM, context=parent_ctx) try: await do_filter_kgraph_orphans(task, logger) + # Always wrap up the task to ACK it in the broker + try: + await wrap_up_task(STREAM, GROUP, task, logger) + except Exception as e: + logger.error(f"Task {task[0]}: Failed to wrap up task: {e}") + except asyncio.CancelledError: + logger.warning(f"Task {task[0]} was cancelled") + except Exception as e: + logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True) + await handle_task_failure(STREAM, GROUP, task, logger) finally: span.end() limiter.release() + logger.info(f"Finished task {task[0]} in {time.time() - start}") async def poll_for_tasks(): - async for task, parent_ctx, logger, limiter in get_tasks( - STREAM, GROUP, CONSUMER, TASK_LIMIT - ): - asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + """On initialization, poll indefinitely for available tasks.""" + while True: + try: + async for task, parent_ctx, logger, limiter in get_tasks( + STREAM, GROUP, CONSUMER, TASK_LIMIT + ): + asyncio.create_task(process_task(task, parent_ctx, logger, limiter)) + except asyncio.CancelledError: + logging.info("Poll loop cancelled, shutting down.") + except Exception as e: + logging.error(f"Error in task polling loop: {e}", exc_info=True) + await asyncio.sleep(5) # back off before retrying if __name__ == "__main__": diff --git a/workers/filter_results_top_n/worker.py b/workers/filter_results_top_n/worker.py index bb83378..b37df9e 100644 --- a/workers/filter_results_top_n/worker.py +++ b/workers/filter_results_top_n/worker.py @@ -6,7 +6,7 @@ import time import uuid from shepherd_utils.db import get_message, save_message, get_query_state -from shepherd_utils.shared import get_tasks, wrap_up_task +from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task from shepherd_utils.otel import setup_tracer # Queue name @@ -18,7 +18,6 @@ async def filter_results_top_n(task, logger: logging.Logger): - start = time.time() # given a task, get the message from the db response_id = task[1]["response_id"] workflow = 
diff --git a/workers/filter_results_top_n/worker.py b/workers/filter_results_top_n/worker.py
index bb83378..b37df9e 100644
--- a/workers/filter_results_top_n/worker.py
+++ b/workers/filter_results_top_n/worker.py
@@ -6,7 +6,7 @@
 import time
 import uuid
 
 from shepherd_utils.db import get_message, save_message, get_query_state
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 from shepherd_utils.otel import setup_tracer
 
 # Queue name
@@ -18,7 +18,6 @@
 
 
 async def filter_results_top_n(task, logger: logging.Logger):
-    start = time.time()
     # given a task, get the message from the db
     response_id = task[1]["response_id"]
     workflow = json.loads(task[1]["workflow"])
@@ -29,35 +28,49 @@
         logger.error(f"Unable to find operation {STREAM} in workflow")
         raise Exception(f"Operation {STREAM} is not in workflow")
     n = current_op.get("max_results", 500)
-    try:
-        message["message"]["results"] = results[:n]
-    except KeyError as e:
-        # can't find the right structure of message
-        logger.error(f"Error filtering results: {e}")
-        return message, 400
+
+    message["message"]["results"] = results[:n]
 
     logger.info("Returning filtered results.")
     # save merged message back to db
     await save_message(response_id, message, logger)
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-    logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await filter_results_top_n(task, logger)
+        # Always wrap up the task to ACK it in the broker
+        try:
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled")
+    except Exception as e:
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+    logger.info(f"Finished task {task[0]} in {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
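Both filter workers read their parameters from the first operation in the task's serialized workflow, falling back to defaults (1000 analyses, 500 results) when the operation omits them. A task payload that exercises this path would carry something like the following; the ids are real stream names from this repo, but the values and the trailing operation are illustrative:

    import json

    task_body = {
        "response_id": "abc123",  # hypothetical id
        "workflow": json.dumps([
            {"id": "filter_results_top_n", "max_results": 100},
            {"id": "filter_kgraph_orphans"},
            {"id": "finish_query"},  # illustrative tail operation
        ]),
    }

wrap_up_task then pops the completed operation off the front of this list and enqueues whichever id comes next.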
diff --git a/workers/finish_query/worker.py b/workers/finish_query/worker.py
index 842d94c..d03ceda 100644
--- a/workers/finish_query/worker.py
+++ b/workers/finish_query/worker.py
@@ -23,13 +23,16 @@
 CONSUMER = str(uuid.uuid4())[:8]
 TASK_LIMIT = 100
 tracer = setup_tracer(STREAM)
+CALLBACK_RETRIES = 3
 
 
 async def finish_query(task, logger: logging.Logger):
+    """Do all the wrap up necessary for a query."""
     start = time.time()
     # given a task, get the message from the db
     query_id = task[1]["query_id"]
     response_id = task[1]["response_id"]
+    status = task[1].get("status", "OK")
 
     query_state = await get_query_state(query_id, logger)
     if query_state is None:
@@ -41,37 +44,59 @@
     message = await get_message(response_id, logger)
     logs = await get_logs(response_id, logger)
     message["logs"] = logs
-    try:
-        async with httpx.AsyncClient(timeout=60) as client:
-            response = await client.post(
-                callback_url,
-                json=message,
-            )
-            response.raise_for_status()
-            logger.info(f"Sent response back to {callback_url}")
-    except Exception as e:
-        logger.error(f"Failed to send callback to {callback_url}: {e}")
+    for attempt in range(CALLBACK_RETRIES):
+        try:
+            async with httpx.AsyncClient(timeout=120) as client:
+                response = await client.post(
+                    callback_url,
+                    json=message,
+                )
+                response.raise_for_status()
+                logger.info(f"Sent response back to {callback_url}")
+            break
+        except Exception as e:
+            logger.error(f"Failed to send callback to {callback_url}: {e}")
+            if attempt < CALLBACK_RETRIES - 1:
+                await asyncio.sleep(1 * (2**attempt))  # exponential backoff
 
-    await set_query_completed(query_id, "OK", logger)
+    await set_query_completed(query_id, status, logger)
 
-    await mark_task_as_complete(STREAM, GROUP, task[0], logger)
     logger.info(f"Finished task {task[0]} in {time.time() - start}")
 
 
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await finish_query(task, logger)
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled")
+    except Exception as e:
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
     finally:
+        # Always wrap up the task to ACK it in the broker
+        try:
+            await mark_task_as_complete(STREAM, GROUP, task[0], logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
         span.end()
         limiter.release()
+    logger.info(f"Finished task {task[0]} in {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
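With CALLBACK_RETRIES = 3 the retry loop waits roughly 1s and then 2s between attempts before giving up. The same policy as a reusable helper, a minimal sketch whose name and signature are hypothetical (not part of shepherd_utils):

    import asyncio

    import httpx

    async def post_with_retries(
        url: str, payload: dict, retries: int = 3, timeout: float = 120
    ) -> bool:
        """POST payload to url with exponential backoff; True on success."""
        for attempt in range(retries):
            try:
                async with httpx.AsyncClient(timeout=timeout) as client:
                    response = await client.post(url, json=payload)
                    response.raise_for_status()
                return True
            except Exception:
                if attempt < retries - 1:
                    await asyncio.sleep(2**attempt)  # 1s, 2s, 4s, ...
        return False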
diff --git a/workers/gandalf/worker.py b/workers/gandalf/worker.py
index a8c539d..5ea4df9 100644
--- a/workers/gandalf/worker.py
+++ b/workers/gandalf/worker.py
@@ -82,75 +82,85 @@
     loop = asyncio.get_running_loop()
     executor = ThreadPoolExecutor(max_workers=1)
 
-    async for task, parent_ctx, task_logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        span = tracer.start_span(STREAM, context=parent_ctx)
-        start = time.time()
-        try:
-            task_logger.info("Got task for Gandalf")
-            response_id = task[1]["response_id"]
-            callback_id = task[1]["callback_id"]
-            target = task[1].get("target", "unknown")
-
-            task_logger.info("Getting message")
-            message = await get_message(callback_id, task_logger)
-            if message is None:
-                task_logger.error(f"Failed to get {response_id} for lookup.")
-                continue
-
-            query_id = await get_callback_query_id(callback_id, task_logger)
-            task_logger.info(f"Got original query id: {query_id}")
-            if query_id is None:
-                task_logger.error("Failed to get original query id.")
-                continue
-
-            query_state = await get_query_state(query_id, task_logger)
-            if query_state is None:
-                task_logger.error("Failed to get query state.")
-                continue
-
-            response_id = query_state[7]
-
-            lookup_response = await loop.run_in_executor(
-                executor,
-                gandalf_lookup,
-                graph,
-                bmt,
-                message,
-                task_logger,
-            )
-
-            if DEBUG_RESPONSES and len(lookup_response["message"]["results"]) > 0:
-                debug_dir = Path("debug")
-                debug_dir.mkdir(exist_ok=True)
-                debug_path = debug_dir / f"{query_id}_{callback_id}_response.json"
-                with open(debug_path, "w", encoding="utf-8") as f:
-                    json.dump(lookup_response, f, indent=2)
-
-            task_logger.info(f"Saving callback {callback_id} to redis")
-            await save_message(callback_id, lookup_response, task_logger)
-            task_logger.info(f"Saved callback {callback_id} to redis")
-
-            await add_task(
-                "merge_message",
-                {
-                    "target": target,
-                    "query_id": query_id,
-                    "response_id": response_id,
-                    "callback_id": callback_id,
-                    "log_level": task[1].get("log_level", 20),
-                    "otel": task[1]["otel"],
-                },
-                task_logger,
-            )
-        except Exception:
-            task_logger.exception(f"Task {task[0]} failed")
-        finally:
-            await mark_task_as_complete(STREAM, GROUP, task[0], logger)
-            task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
-            span.end()
-            limiter.release()
+    while True:
+        try:
+            async for task, parent_ctx, task_logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                span = tracer.start_span(STREAM, context=parent_ctx)
+                start = time.time()
+                try:
+                    task_logger.info("Got task for Gandalf")
+                    response_id = task[1]["response_id"]
+                    callback_id = task[1]["callback_id"]
+                    target = task[1].get("target", "unknown")
+
+                    task_logger.info("Getting message")
+                    message = await get_message(callback_id, task_logger)
+                    if message is None:
+                        task_logger.error(f"Failed to get {response_id} for lookup.")
+                        continue
+
+                    query_id = await get_callback_query_id(callback_id, task_logger)
+                    task_logger.info(f"Got original query id: {query_id}")
+                    if query_id is None:
+                        task_logger.error("Failed to get original query id.")
+                        continue
+
+                    query_state = await get_query_state(query_id, task_logger)
+                    if query_state is None:
+                        task_logger.error("Failed to get query state.")
+                        continue
+
+                    response_id = query_state[7]
+
+                    lookup_response = await loop.run_in_executor(
+                        executor,
+                        gandalf_lookup,
+                        graph,
+                        bmt,
+                        message,
+                        task_logger,
+                    )
+
+                    if DEBUG_RESPONSES and len(lookup_response["message"]["results"]) > 0:
+                        debug_dir = Path("debug")
+                        debug_dir.mkdir(exist_ok=True)
+                        debug_path = debug_dir / f"{query_id}_{callback_id}_response.json"
+                        with open(debug_path, "w", encoding="utf-8") as f:
+                            json.dump(lookup_response, f, indent=2)
+
+                    task_logger.info(f"Saving callback {callback_id} to redis")
+                    await save_message(callback_id, lookup_response, task_logger)
+                    task_logger.info(f"Saved callback {callback_id} to redis")
+
+                    await add_task(
+                        "merge_message",
+                        {
+                            "target": target,
+                            "query_id": query_id,
+                            "response_id": response_id,
+                            "callback_id": callback_id,
+                            "log_level": task[1].get("log_level", 20),
+                            "otel": task[1]["otel"],
+                        },
+                        task_logger,
+                    )
+                except Exception:
+                    task_logger.exception(f"Task {task[0]} failed")
+                finally:
+                    try:
+                        await mark_task_as_complete(STREAM, GROUP, task[0], logger)
+                    except Exception as e:
+                        logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+                    task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
+                    span.end()
+                    limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
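gandalf_lookup is blocking, CPU-bound work, so the worker keeps the event loop responsive by handing it to a ThreadPoolExecutor and awaiting the resulting future; max_workers=1 presumably serializes lookups against the shared in-memory CSRGraph. The shape of that hand-off, reduced to a self-contained example where blocking_score stands in for the real lookup:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    def blocking_score(n: int) -> int:
        """Placeholder for CPU-bound work like gandalf_lookup."""
        return sum(i * i for i in range(n))

    async def main() -> None:
        loop = asyncio.get_running_loop()
        executor = ThreadPoolExecutor(max_workers=1)
        # run_in_executor takes the callable plus its positional args
        result = await loop.run_in_executor(executor, blocking_score, 10_000)
        print(result)

    asyncio.run(main())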
diff --git a/workers/gandalf_rehydrate/worker.py b/workers/gandalf_rehydrate/worker.py
index c51b985..c5647dc 100644
--- a/workers/gandalf_rehydrate/worker.py
+++ b/workers/gandalf_rehydrate/worker.py
@@ -18,7 +18,7 @@
     save_message,
 )
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "gandalf.rehydrate"
@@ -80,47 +80,57 @@
     loop = asyncio.get_running_loop()
     executor = ThreadPoolExecutor(max_workers=1)
 
-    async for task, parent_ctx, task_logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        span = tracer.start_span(STREAM, context=parent_ctx)
-        start = time.time()
-        task_logger.info("Got task for Gandalf Rehydration")
-        workflow = json.loads(task[1]["workflow"])
-        response_id = task[1]["response_id"]
-        try:
-            message = await get_message(response_id, task_logger)
-            if message is None:
-                task_logger.error(f"Failed to get {response_id} for rehydration.")
-                continue
-
-            hydrated_response = await loop.run_in_executor(
-                executor,
-                gandalf_rehydration,
-                graph,
-                bmt,
-                message,
-                task_logger,
-            )
-
-            if DEBUG_RESPONSES and len(hydrated_response["message"]["results"]) > 0:
-                debug_dir = Path("debug")
-                debug_dir.mkdir(exist_ok=True)
-                debug_path = debug_dir / f"{response_id}_response.json"
-                with open(debug_path, "w", encoding="utf-8") as f:
-                    json.dump(hydrated_response, f, indent=2)
-
-            task_logger.info(f"Saving response {response_id} to redis")
-            await save_message(response_id, hydrated_response, task_logger)
-            task_logger.info(f"Saved response {response_id} to redis")
-
-        except Exception:
-            task_logger.exception(f"Task {task[0]} failed")
-        finally:
-            await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-            task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
-            span.end()
-            limiter.release()
+    while True:
+        try:
+            async for task, parent_ctx, task_logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                span = tracer.start_span(STREAM, context=parent_ctx)
+                start = time.time()
+                task_logger.info("Got task for Gandalf Rehydration")
+                response_id = task[1]["response_id"]
+                try:
+                    message = await get_message(response_id, task_logger)
+                    if message is None:
+                        task_logger.error(f"Failed to get {response_id} for rehydration.")
+                        # fail the query so the task is still ACKed and the client notified
+                        await handle_task_failure(STREAM, GROUP, task, task_logger)
+                        continue
+
+                    hydrated_response = await loop.run_in_executor(
+                        executor,
+                        gandalf_rehydration,
+                        graph,
+                        bmt,
+                        message,
+                        task_logger,
+                    )
+
+                    if DEBUG_RESPONSES and len(hydrated_response["message"]["results"]) > 0:
+                        debug_dir = Path("debug")
+                        debug_dir.mkdir(exist_ok=True)
+                        debug_path = debug_dir / f"{response_id}_response.json"
+                        with open(debug_path, "w", encoding="utf-8") as f:
+                            json.dump(hydrated_response, f, indent=2)
+
+                    await save_message(response_id, hydrated_response, task_logger)
+                    # Always wrap up the task to ACK it in the broker
+                    try:
+                        await wrap_up_task(STREAM, GROUP, task, logger)
+                    except Exception as e:
+                        logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+
+                except Exception:
+                    task_logger.exception(f"Task {task[0]} failed")
+                    await handle_task_failure(STREAM, GROUP, task, logger)
+                finally:
+                    task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
+                    span.end()
+                    limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
diff --git a/workers/merge_message/worker.py b/workers/merge_message/worker.py
index 9c551b1..5e6861f 100644
--- a/workers/merge_message/worker.py
+++ b/workers/merge_message/worker.py
@@ -597,64 +597,80 @@
     cpu_count = cpu_count if cpu_count is not None else 1
     cpu_count = min(cpu_count, TASK_LIMIT)
     executor = ProcessPoolExecutor(max_workers=cpu_count)
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, cpu_count
-    ):
-        span = tracer.start_span(STREAM, context=parent_ctx)
-        query_id = task[1]["query_id"]
-        response_id = task[1]["response_id"]
-        callback_id = task[1]["callback_id"]
-        target = task[1]["target"]
-        got_lock = await acquire_lock(response_id, CONSUMER, logger)
-        if got_lock:
-            logger.info(f"[{callback_id}] Obtained lock.")
-
-            # given a task, get the message from the db
-            original_query = await get_message(query_id, logger)
-            if original_query is None:
-                logger.error(
-                    f"Failed to get original query for {query_id}. Discarding callback response."
-                )
-                await remove_lock(response_id, CONSUMER, logger)
-                await remove_callback_id(callback_id, logger)
-                limiter.release()
-                await mark_task_as_complete(STREAM, GROUP, task[0], logger)
-                span.end()
-                continue
-            original_query_graph = original_query["message"]["query_graph"]
-            callback_response = await get_message(callback_id, logger)
-            lock_time = time.time()
-            original_response = await get_message(response_id, logger)
-            # do message merging
-            try:
-                merged_message = await loop.run_in_executor(
-                    executor,
-                    merge_messages,
-                    target,
-                    original_query_graph,
-                    original_response,
-                    callback_response,
-                    logger,
-                )
-                # save merged message back to db
-                await save_message(response_id, merged_message, logger)
-            except Exception:
-                logger.error(
-                    f"[{callback_id}] Error merging message: {traceback.format_exc()}"
-                )
-            logger.info(
-                f"[{callback_id}] Kept the lock for {time.time() - lock_time} seconds"
-            )
-            # remove lock so others can now modify message
-            await remove_lock(response_id, CONSUMER, logger)
-        else:
-            logger.error(
-                f"Failed to obtain lock for {query_id}. Discarding callback response."
-            )
-        await remove_callback_id(callback_id, logger)
-        limiter.release()
-        await mark_task_as_complete(STREAM, GROUP, task[0], logger)
-        span.end()
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, cpu_count
+            ):
+                start = time.time()
+                span = tracer.start_span(STREAM, context=parent_ctx)
+                try:
+                    query_id = task[1]["query_id"]
+                    response_id = task[1]["response_id"]
+                    callback_id = task[1]["callback_id"]
+                    target = task[1]["target"]
+                    got_lock = await acquire_lock(response_id, CONSUMER, logger)
+                    if got_lock:
+                        logger.info(f"[{callback_id}] Obtained lock.")
+
+                        # given a task, get the message from the db
+                        original_query = await get_message(query_id, logger)
+                        if original_query is None:
+                            logger.error(
+                                f"Failed to get original query for {query_id}. Discarding callback response."
+                            )
+                            await remove_lock(response_id, CONSUMER, logger)
+                            await remove_callback_id(callback_id, logger)
+                            # the finally block below ACKs the task and releases the span/limiter
+                            continue
+                        original_query_graph = original_query["message"]["query_graph"]
+                        callback_response = await get_message(callback_id, logger)
+                        lock_time = time.time()
+                        original_response = await get_message(response_id, logger)
+                        # do message merging
+                        try:
+                            merged_message = await loop.run_in_executor(
+                                executor,
+                                merge_messages,
+                                target,
+                                original_query_graph,
+                                original_response,
+                                callback_response,
+                                logger,
+                            )
+                            # save merged message back to db
+                            await save_message(response_id, merged_message, logger)
+                        except Exception:
+                            logger.error(
+                                f"[{callback_id}] Error merging message: {traceback.format_exc()}"
+                            )
+                        logger.info(
+                            f"[{callback_id}] Kept the lock for {time.time() - lock_time} seconds"
+                        )
+                        # remove lock so others can now modify message
+                        await remove_lock(response_id, CONSUMER, logger)
+                        await remove_callback_id(callback_id, logger)
+                    else:
+                        logger.error(
+                            f"Failed to obtain lock for {query_id}. Discarding callback response."
+                        )
+                        await remove_callback_id(callback_id, logger)
+                except Exception as e:
+                    logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+                finally:
+                    try:
+                        await mark_task_as_complete(STREAM, GROUP, task[0], logger)
+                    except Exception as e:
+                        logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+                    logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
+                    span.end()
+                    limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
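merge_message serializes writers per response_id with the repo's acquire_lock/remove_lock pair, and runs the merge itself in a ProcessPoolExecutor since TRAPI merges are CPU-heavy. The locking discipline, reduced to a skeleton that assumes only the call signatures visible above (merge_under_lock and do_merge are hypothetical names):

    async def merge_under_lock(response_id, consumer, do_merge, logger):
        """Sketch: run do_merge only while holding the per-response lock."""
        if not await acquire_lock(response_id, consumer, logger):
            return False  # caller decides: discard or requeue the callback
        try:
            await do_merge()  # read-modify-write the shared response
        finally:
            await remove_lock(response_id, consumer, logger)  # always released
        return True

Wrapping the critical section in try/finally, rather than releasing the lock on each early-exit path as the worker does, is the simplest way to guarantee the lock is dropped even when the merge raises.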
diff --git a/workers/sipr/worker.py b/workers/sipr/worker.py
index c7c555e..f585c9f 100644
--- a/workers/sipr/worker.py
+++ b/workers/sipr/worker.py
@@ -1,6 +1,7 @@
 """SIPR (Set-Input Page Rank) module."""
 
 import asyncio
+import json
 import logging
 import time
 import uuid
@@ -15,7 +16,7 @@
     save_message,
 )
 from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 
 # Queue name
 STREAM = "sipr"
@@ -196,11 +197,6 @@
 
 
 async def sipr(task, logger: logging.Logger):
-    start = time.time()
-    workflow = [
-        {"id": "sipr"},
-        {"id": "sort_results_score"},
-    ]
     try:
         # given a task, get the message from the db
         logger.info("Getting message from db")
@@ -332,30 +328,46 @@
 
     except NotImplementedError:
         logger.info("SIPR only supports Set Input Queries.")
-    except Exception as e:
-        logger.error(f"Something bad happened! {e}")
-
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-
-    logger.info(f"Finished task {task[0]} in {time.time() - start}")
+    task[1]["workflow"] = json.dumps([
+        {"id": "sipr"},
+        {"id": "sort_results_score"},
+    ])
 
 
 async def process_task(task, parent_ctx, logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await sipr(task, logger)
+        try:
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled.")
     except Exception as e:
-        logger.error(f"Something went wrong: {e}")
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+    logger.info(f"Task took {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
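sipr no longer calls wrap_up_task itself; instead it rewrites the task's serialized workflow so the shared wrap-up logic pops "sipr" off the front and enqueues "sort_results_score" next. Because wrap_up_task re-reads task[1]["workflow"], any worker could chain a follow-up stage the same way; a hypothetical helper illustrating the idiom (chain_next is not part of shepherd_utils):

    import json

    def chain_next(task: tuple, current: str, next_ops: list[dict]) -> None:
        """Hypothetical: make wrap_up_task schedule next_ops after current."""
        task[1]["workflow"] = json.dumps([{"id": current}, *next_ops])

    # inside a worker whose stream is "sipr":
    # chain_next(task, "sipr", [{"id": "sort_results_score"}])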
diff --git a/workers/sort_results_score/worker.py b/workers/sort_results_score/worker.py
index 00ce2f4..ee96ff0 100644
--- a/workers/sort_results_score/worker.py
+++ b/workers/sort_results_score/worker.py
@@ -6,7 +6,7 @@
 import time
 import uuid
 
 from shepherd_utils.db import get_message, save_message, get_query_state
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
 from shepherd_utils.otel import setup_tracer
 
 # Queue name
@@ -18,7 +18,6 @@
 
 
 async def sort_results_score(task, logger: logging.Logger):
-    start = time.time()
     # given a task, get the message from the db
     response_id = task[1]["response_id"]
     workflow = json.loads(task[1]["workflow"])
@@ -27,53 +26,64 @@
     current_op = workflow[0]
     aord = current_op.get("ascending_or_descending", "descending")
     reverse = aord == "descending"
-    try:
-        for ind, result in enumerate(results):
-            message["message"]["results"][ind]["analyses"] = sorted(
-                result["analyses"],
-                key=lambda x: x.get("score", 0),
-                reverse=reverse,
-            )
-        if reverse:
-            message["message"]["results"] = sorted(
-                results,
-                key=lambda x: x["analyses"][0].get("score", 0),
-                reverse=reverse,
-            )
-        else:
-            message["message"]["results"] = sorted(
-                results,
-                key=lambda x: x["analyses"][-1].get("score", 0),
-                reverse=reverse,
-            )
-    except KeyError as e:
-        # can't find the right structure of message
-        err = f"Error sorting results: {e}"
-        logger.error(err)
-        raise KeyError(err)
+    for ind, result in enumerate(results):
+        message["message"]["results"][ind]["analyses"] = sorted(
+            result["analyses"],
+            key=lambda x: x.get("score", 0),
+            reverse=reverse,
+        )
+    if reverse:
+        message["message"]["results"] = sorted(
+            results,
+            key=lambda x: x["analyses"][0].get("score", 0),
+            reverse=reverse,
+        )
+    else:
+        message["message"]["results"] = sorted(
+            results,
+            key=lambda x: x["analyses"][-1].get("score", 0),
+            reverse=reverse,
+        )
 
     logger.info("Returning sorted results.")
     # save merged message back to db
     await save_message(response_id, message, logger)
-    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-
-    logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
 
 async def process_task(task, parent_ctx, logger, limiter):
+    """Process a given task and ACK in redis."""
+    start = time.time()
     span = tracer.start_span(STREAM, context=parent_ctx)
     try:
         await sort_results_score(task, logger)
+        try:
+            await wrap_up_task(STREAM, GROUP, task, logger)
+        except Exception as e:
+            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+    except asyncio.CancelledError:
+        logger.warning(f"Task {task[0]} was cancelled.")
+    except Exception as e:
+        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+        await handle_task_failure(STREAM, GROUP, task, logger)
     finally:
         span.end()
         limiter.release()
+    logger.info(f"Task took {time.time() - start}")
 
 
 async def poll_for_tasks():
-    async for task, parent_ctx, logger, limiter in get_tasks(
-        STREAM, GROUP, CONSUMER, TASK_LIMIT
-    ):
-        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """On initialization, poll indefinitely for available tasks."""
+    while True:
+        try:
+            async for task, parent_ctx, logger, limiter in get_tasks(
+                STREAM, GROUP, CONSUMER, TASK_LIMIT
+            ):
+                asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+        except Exception as e:
+            logging.error(f"Error in task polling loop: {e}", exc_info=True)
+            await asyncio.sleep(5)  # back off before retrying
 
 
 if __name__ == "__main__":
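Both branches rank results by their best analysis score: after the per-result sort, that score sits at index 0 when sorting descending and at index -1 when ascending, so the two index choices pick the same maximum. A toy check of the descending branch:

    results = [
        {"analyses": [{"score": 0.2}, {"score": 0.9}]},
        {"analyses": [{"score": 0.5}]},
    ]
    for result in results:
        result["analyses"].sort(key=lambda a: a.get("score", 0), reverse=True)
    results.sort(key=lambda r: r["analyses"][0].get("score", 0), reverse=True)
    assert [r["analyses"][0]["score"] for r in results] == [0.9, 0.5]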
From c095b2ee0e932e9bbc83b5dcaf4d4ee5751bd3e0 Mon Sep 17 00:00:00 2001
From: Max Wang
Date: Wed, 4 Mar 2026 11:35:18 -0500
Subject: [PATCH 2/4] Bump gandalf version

---
 workers/gandalf/requirements.txt           | 3 ++-
 workers/gandalf_rehydrate/requirements.txt | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/workers/gandalf/requirements.txt b/workers/gandalf/requirements.txt
index 227b37b..64cd655 100644
--- a/workers/gandalf/requirements.txt
+++ b/workers/gandalf/requirements.txt
@@ -1 +1,2 @@
-gandalf-csr>=0.1.8
+gandalf-csr>=0.1.11
+
diff --git a/workers/gandalf_rehydrate/requirements.txt b/workers/gandalf_rehydrate/requirements.txt
index 227b37b..64cd655 100644
--- a/workers/gandalf_rehydrate/requirements.txt
+++ b/workers/gandalf_rehydrate/requirements.txt
@@ -1 +1,2 @@
-gandalf-csr>=0.1.8
+gandalf-csr>=0.1.11
+

From 26343e9a724d8187835640d19feff729d1514c1a Mon Sep 17 00:00:00 2001
From: Max Wang
Date: Wed, 4 Mar 2026 11:35:49 -0500
Subject: [PATCH 3/4] Bump patch version

---
 shepherd_server/openapi-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/shepherd_server/openapi-config.yaml b/shepherd_server/openapi-config.yaml
index e7f92ce..00b63f3 100644
--- a/shepherd_server/openapi-config.yaml
+++ b/shepherd_server/openapi-config.yaml
@@ -4,7 +4,7 @@
 contact:
   x-id: https://github.com/maximusunc
   x-role: responsible developer
 description: '
 
   Shepherd: Translator Autonomous Relay Agent Platform'
-version: 0.6.7
+version: 0.6.8
 servers:
 - description: Default server
   url: https://shepherd.renci.org

From 3947b24cd3e90a4879408a959634352e9ab8656a Mon Sep 17 00:00:00 2001
From: Max Wang
Date: Wed, 4 Mar 2026 11:36:35 -0500
Subject: [PATCH 4/4] Run black

---
 workers/aragorn_score/worker.py     |  5 ++++-
 workers/arax_rank/worker.py         |  9 +++++++--
 workers/gandalf/worker.py           | 13 ++++++++++---
 workers/gandalf_rehydrate/worker.py | 13 ++++++++++---
 workers/merge_message/worker.py     |  9 +++++++--
 workers/sipr/worker.py              | 10 ++++++----
 6 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/workers/aragorn_score/worker.py b/workers/aragorn_score/worker.py
index b12ebc3..db28de0 100644
--- a/workers/aragorn_score/worker.py
+++ b/workers/aragorn_score/worker.py
@@ -1256,7 +1256,10 @@
         except asyncio.CancelledError:
             logger.warning(f"Task {task[0]} was cancelled")
         except Exception as e:
-            logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+            logger.error(
+                f"Task {task[0]} failed with unhandled error: {e}",
+                exc_info=True,
+            )
             await handle_task_failure(STREAM, GROUP, task, logger)
         finally:
             logger.info(f"Finished task {task[0]} in {time.time() - start}")
diff --git a/workers/arax_rank/worker.py b/workers/arax_rank/worker.py
index 109bda1..783e7a4 100644
--- a/workers/arax_rank/worker.py
+++ b/workers/arax_rank/worker.py
@@ -116,7 +116,9 @@
                         logger,
                     )
                     if ranked_message is None:
-                        logger.error("Ranking returned None. Returning original message.")
+                        logger.error(
+                            "Ranking returned None. Returning original message."
+                        )
                         ranked_message = message
                     await save_message(response_id, ranked_message, logger)
                 else:
@@ -130,7 +132,10 @@
                 except asyncio.CancelledError:
                     logger.warning(f"Task {task[0]} was cancelled")
                 except Exception as e:
-                    logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+                    logger.error(
+                        f"Task {task[0]} failed with unhandled error: {e}",
+                        exc_info=True,
+                    )
                     await handle_task_failure(STREAM, GROUP, task, logger)
                 finally:
                     logger.info(f"Finished task {task[0]} in {time.time() - start}")
diff --git a/workers/gandalf/worker.py b/workers/gandalf/worker.py
index 5ea4df9..a08c4c9 100644
--- a/workers/gandalf/worker.py
+++ b/workers/gandalf/worker.py
@@ -123,10 +123,15 @@
                         task_logger,
                     )
 
-                    if DEBUG_RESPONSES and len(lookup_response["message"]["results"]) > 0:
+                    if (
+                        DEBUG_RESPONSES
+                        and len(lookup_response["message"]["results"]) > 0
+                    ):
                         debug_dir = Path("debug")
                         debug_dir.mkdir(exist_ok=True)
-                        debug_path = debug_dir / f"{query_id}_{callback_id}_response.json"
+                        debug_path = (
+                            debug_dir / f"{query_id}_{callback_id}_response.json"
+                        )
                         with open(debug_path, "w", encoding="utf-8") as f:
                             json.dump(lookup_response, f, indent=2)
 
@@ -153,7 +158,9 @@
                             await mark_task_as_complete(STREAM, GROUP, task[0], logger)
                         except Exception as e:
                             logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
-                    task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
+                    task_logger.info(
+                        f"Finished task {task[0]} in {time.time() - start:.2f}s"
+                    )
                     span.end()
                     limiter.release()
diff --git a/workers/gandalf_rehydrate/worker.py b/workers/gandalf_rehydrate/worker.py
index c5647dc..dbffeae 100644
--- a/workers/gandalf_rehydrate/worker.py
+++ b/workers/gandalf_rehydrate/worker.py
@@ -93,7 +93,9 @@
                 try:
                     message = await get_message(response_id, task_logger)
                     if message is None:
-                        task_logger.error(f"Failed to get {response_id} for rehydration.")
+                        task_logger.error(
+                            f"Failed to get {response_id} for rehydration."
+                        )
                         continue
 
                     hydrated_response = await loop.run_in_executor(
@@ -105,7 +107,10 @@
                         task_logger,
                     )
 
-                    if DEBUG_RESPONSES and len(hydrated_response["message"]["results"]) > 0:
+                    if (
+                        DEBUG_RESPONSES
+                        and len(hydrated_response["message"]["results"]) > 0
+                    ):
                         debug_dir = Path("debug")
                         debug_dir.mkdir(exist_ok=True)
                         debug_path = debug_dir / f"{response_id}_response.json"
@@ -123,7 +128,9 @@
                     task_logger.exception(f"Task {task[0]} failed")
                     await handle_task_failure(STREAM, GROUP, task, logger)
                 finally:
-                    task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
+                    task_logger.info(
+                        f"Finished task {task[0]} in {time.time() - start:.2f}s"
+                    )
                     span.end()
                     limiter.release()
diff --git a/workers/merge_message/worker.py b/workers/merge_message/worker.py
index 5e6861f..1ffd246 100644
--- a/workers/merge_message/worker.py
+++ b/workers/merge_message/worker.py
@@ -657,13 +657,18 @@
                             f"Failed to obtain lock for {query_id}. Discarding callback response."
                         )
                 except Exception as e:
-                    logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+                    logger.error(
+                        f"Task {task[0]} failed with unhandled error: {e}",
+                        exc_info=True,
+                    )
                 finally:
                     try:
                         await mark_task_as_complete(STREAM, GROUP, task[0], logger)
                     except Exception as e:
                         logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
-                    logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
+                    logger.info(
+                        f"Finished task {task[0]} in {time.time() - start:.2f}s"
+                    )
                     span.end()
                     limiter.release()
diff --git a/workers/sipr/worker.py b/workers/sipr/worker.py
index f585c9f..afe8a61 100644
--- a/workers/sipr/worker.py
+++ b/workers/sipr/worker.py
@@ -328,10 +328,12 @@
 
     except NotImplementedError:
         logger.info("SIPR only supports Set Input Queries.")
-    task[1]["workflow"] = json.dumps([
-        {"id": "sipr"},
-        {"id": "sort_results_score"},
-    ])
+    task[1]["workflow"] = json.dumps(
+        [
+            {"id": "sipr"},
+            {"id": "sort_results_score"},
+        ]
+    )
 
 
 async def process_task(task, parent_ctx, logger, limiter):
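Each worker's entrypoint runs its poll loop until the process is cancelled; because the loop catches asyncio.CancelledError and breaks, cancellation drains cleanly rather than killing the task mid-await. A minimal harness showing that shutdown path (all names illustrative, not from the repo):

    import asyncio

    async def poll_forever() -> None:
        """Stand-in for a worker's poll_for_tasks loop."""
        while True:
            try:
                await asyncio.sleep(1)  # stand-in for get_tasks
            except asyncio.CancelledError:
                print("Poll loop cancelled, shutting down.")
                break

    async def main() -> None:
        poller = asyncio.create_task(poll_forever())
        await asyncio.sleep(3)  # simulate process lifetime
        poller.cancel()         # request shutdown
        await poller            # returns cleanly because the loop breaks

    asyncio.run(main())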