diff --git a/shepherd_server/openapi-config.yaml b/shepherd_server/openapi-config.yaml
index e7f92ce..00b63f3 100644
--- a/shepherd_server/openapi-config.yaml
+++ b/shepherd_server/openapi-config.yaml
@@ -4,7 +4,7 @@ contact:
x-id: https://github.com/maximusunc
x-role: responsible developer
description: '
Shepherd: Translator Autonomous Relay Agent Platform'
-version: 0.6.7
+version: 0.6.8
servers:
- description: Default server
url: https://shepherd.renci.org
diff --git a/shepherd_utils/db.py b/shepherd_utils/db.py
index a4e1149..c0ed9be 100644
--- a/shepherd_utils/db.py
+++ b/shepherd_utils/db.py
@@ -184,16 +184,14 @@ async def get_message(
"""Get the message from db."""
message = {}
start = time.time()
- try:
- # print(f"Putting {query_id} on {ara_target} stream")
- message = await data_db_client.get(message_id)
- if message is not None:
- start_decomp = time.time()
- message = orjson.loads(zstandard.decompress(message))
- logger.debug(f"Decompression took {time.time() - start_decomp}")
- except Exception as e:
+ message = await data_db_client.get(message_id)
+ if message is None:
# failed to get message from db
- logger.error(f"Failed to get {message_id} from db: {e}")
+ raise Exception(f"Failed to get {message_id} from db: {e}")
+
+ start_decomp = time.time()
+ message = orjson.loads(zstandard.decompress(message))
+ logger.debug(f"Decompression took {time.time() - start_decomp}")
logger.debug(f"Getting message took {time.time() - start} seconds")
return message
@@ -341,6 +339,7 @@ async def get_running_callbacks(
continue
except Exception as e:
logger.error(f"Failed to get running lookups: {e}")
+ raise
return running_lookups
diff --git a/shepherd_utils/shared.py b/shepherd_utils/shared.py
index 1c11b90..65e6ee4 100644
--- a/shepherd_utils/shared.py
+++ b/shepherd_utils/shared.py
@@ -75,10 +75,11 @@ async def wrap_up_task(
stream: str,
group: str,
task: Tuple[str, dict],
- workflow: List[dict],
logger: logging.Logger,
):
"""Call the next task and mark this one as complete."""
+ workflow = json.loads(task[1]["workflow"])
+    logger.debug(workflow)
# remove the operation we just did
if stream == workflow[0]["id"]:
# make sure the worker is in the workflow
@@ -106,6 +107,30 @@ async def wrap_up_task(
await save_logs(task[1]["response_id"], logger)
+async def handle_task_failure(
+ stream: str,
+ group: str,
+ task: Tuple[str, dict],
+ logger: logging.Logger,
+) -> None:
+ """Handle any full query failures."""
+ await mark_task_as_complete(stream, group, task[0], logger)
+ await save_logs(task[1]["response_id"], logger)
+ logger.error("Sending task straight to finish_query.")
+ await add_task(
+ "finish_query",
+ {
+ "query_id": task[1]["query_id"],
+ "response_id": task[1]["response_id"],
+ "workflow": "[]",
+ "log_level": task[1].get("log_level", 20),
+ "otel": task[1]["otel"],
+ "status": "ERROR",
+ },
+ logger,
+ )
+
+
def recursive_get_edge_support_graphs(
edge: str,
edges: set,
diff --git a/workers/aragorn/worker.py b/workers/aragorn/worker.py
index 3b60502..c1c317e 100644
--- a/workers/aragorn/worker.py
+++ b/workers/aragorn/worker.py
@@ -8,7 +8,7 @@
from shepherd_utils.db import get_message
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "aragorn"
@@ -31,27 +31,27 @@ def examine_query(message):
# this can still fail if the input looks like e.g.:
# "query_graph": None
qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
- except:
+    except AttributeError:
qedges = {}
try:
# this can still fail if the input looks like e.g.:
# "query_graph": None
qpaths = message.get("message", {}).get("query_graph", {}).get("paths", {})
- except:
+    except AttributeError:
qpaths = {}
if len(qpaths) > 1:
- raise Exception("Only a single path is supported", 400)
+ raise Exception("Only a single path is supported")
if (len(qpaths) > 0) and (len(qedges) > 0):
- raise Exception("Mixed mode pathfinder queries are not supported", 400)
+ raise Exception("Mixed mode pathfinder queries are not supported")
pathfinder = len(qpaths) == 1
n_infer_edges = 0
for edge_id in qedges:
if qedges.get(edge_id, {}).get("knowledge_type", "lookup") == "inferred":
n_infer_edges += 1
if n_infer_edges > 1 and n_infer_edges:
- raise Exception("Only a single infer edge is supported", 400)
+ raise Exception("Only a single infer edge is supported")
if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
- raise Exception("Mixed infer and lookup queries not supported", 400)
+ raise Exception("Mixed infer and lookup queries not supported")
infer = n_infer_edges == 1
if not infer:
return infer, None, None, pathfinder
@@ -64,23 +64,19 @@ def examine_query(message):
else:
question_node = qnode_id
if answer_node is None:
- raise Exception("Both nodes of creative edge pinned", 400)
+ raise Exception("Both nodes of creative edge pinned")
if question_node is None:
- raise Exception("No nodes of creative edge pinned", 400)
+ raise Exception("No nodes of creative edge pinned")
return infer, question_node, answer_node, pathfinder
async def aragorn(task, logger: logging.Logger):
- start = time.time()
+ """Examine and define the Aragorn workflow."""
# given a task, get the message from the db
query_id = task[1]["query_id"]
workflow = json.loads(task[1]["workflow"])
message = await get_message(query_id, logger)
- try:
- infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
- except Exception as e:
- logger.error(e)
- return None, 500
+ infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
if workflow is None:
if infer:
@@ -113,24 +109,44 @@ async def aragorn(task, logger: logging.Logger):
{"id": "filter_kgraph_orphans"},
]
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Task took {time.time() - start}")
+ task[1]["workflow"] = json.dumps(workflow)
async def process_task(task, parent_ctx, logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await aragorn(task, logger)
+ try:
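+            # Always wrap up the task to ACK it in the broker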
+            logger.debug(task)
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled.")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Task took {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/aragorn_lookup/worker.py b/workers/aragorn_lookup/worker.py
index 56e7c38..3ac7374 100644
--- a/workers/aragorn_lookup/worker.py
+++ b/workers/aragorn_lookup/worker.py
@@ -23,7 +23,7 @@
save_message,
)
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import add_task, get_tasks, wrap_up_task
+from shepherd_utils.shared import add_task, get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "aragorn.lookup"
@@ -46,7 +46,7 @@ def examine_query(message):
# this can still fail if the input looks like e.g.:
# "query_graph": None
qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
- except:
+    except AttributeError:
qedges = {}
n_infer_edges = 0
for edge_id in qedges:
@@ -54,9 +54,9 @@ def examine_query(message):
n_infer_edges += 1
pathfinder = n_infer_edges == 3
if n_infer_edges > 1 and n_infer_edges and not pathfinder:
- raise Exception("Only a single infer edge is supported", 400)
+ raise Exception("Only a single infer edge is supported")
if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
- raise Exception("Mixed infer and lookup queries not supported", 400)
+ raise Exception("Mixed infer and lookup queries not supported")
infer = n_infer_edges == 1
if not infer:
return infer, None, None, pathfinder
@@ -69,9 +69,9 @@ def examine_query(message):
else:
question_node = qnode_id
if answer_node is None:
- raise Exception("Both nodes of creative edge pinned", 400)
+ raise Exception("Both nodes of creative edge pinned")
if question_node is None:
- raise Exception("No nodes of creative edge pinned", 400)
+ raise Exception("No nodes of creative edge pinned")
return infer, question_node, answer_node, pathfinder
@@ -109,11 +109,10 @@ async def run_async_lookup(
async def aragorn_lookup(task, logger: logging.Logger):
- start = time.time()
+ """Do Aragorn lookup operation."""
# given a task, get the message from the db
query_id = task[1]["query_id"]
response_id = task[1]["response_id"]
- workflow = json.loads(task[1]["workflow"])
message = await get_message(query_id, logger)
parameters = message.get("parameters") or {}
parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
@@ -241,9 +240,13 @@ async def aragorn_lookup(task, logger: logging.Logger):
start_time = time.time()
running_callback_ids = [""]
while time.time() - start_time < MAX_QUERY_TIME:
- # see if there are existing lookups going
- running_callback_ids = await get_running_callbacks(query_id, logger)
- # logger.info(f"Got back {len(running_callback_ids)} running lookups")
+ try:
+ # see if there are existing lookups going
+ running_callback_ids = await get_running_callbacks(query_id, logger)
+ except Exception:
+ # Brief backoff then retry the check rather than giving up
+ await asyncio.sleep(5)
+ continue
# if there are, continue to wait
if len(running_callback_ids) > 0:
await asyncio.sleep(1)
@@ -260,9 +263,6 @@ async def aragorn_lookup(task, logger: logging.Logger):
# logger.warning(f"Running callbacks: {running_callback_ids}")
await cleanup_callbacks(query_id, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
def get_infer_parameters(input_message):
"""Given an infer input message, return the parameters needed to run the infer.
@@ -386,20 +386,41 @@ def expand_aragorn_query(input_message, logger: logging.Logger):
return messages
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await aragorn_lookup(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/aragorn_omnicorp/worker.py b/workers/aragorn_omnicorp/worker.py
index b2ae04e..9a46036 100644
--- a/workers/aragorn_omnicorp/worker.py
+++ b/workers/aragorn_omnicorp/worker.py
@@ -7,9 +7,9 @@
import time
import uuid
from shepherd_utils.config import settings
-from shepherd_utils.db import get_message
+from shepherd_utils.db import get_message, save_message
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "aragorn.omnicorp"
@@ -21,36 +21,56 @@
async def aragorn_omnicorp(task, logger: logging.Logger):
- start = time.time()
# given a task, get the message from the db
response_id = task[1]["response_id"]
- workflow = json.loads(task[1]["workflow"])
message = await get_message(response_id, logger)
- async with httpx.AsyncClient(timeout=100) as client:
- await client.post(
+ async with httpx.AsyncClient(timeout=120) as client:
+ response = await client.post(
settings.omnicorp_url,
json=message,
)
+ response.raise_for_status()
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Task took {time.time() - start}")
+ response = response.json()
+ await save_message(response_id, response, logger)
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await aragorn_omnicorp(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/aragorn_pathfinder/worker.py b/workers/aragorn_pathfinder/worker.py
index 9afe88c..2c2be4a 100644
--- a/workers/aragorn_pathfinder/worker.py
+++ b/workers/aragorn_pathfinder/worker.py
@@ -18,6 +18,7 @@
from shepherd_utils.shared import (
add_task,
get_tasks,
+ handle_task_failure,
wrap_up_task,
)
@@ -35,10 +36,8 @@ async def shadowfax(task, logger: logging.Logger):
co-occurrence to find nodes that occur in publications with our input
nodes, then finding paths that connect our input nodes through these
intermediate nodes."""
- start = time.time()
# given a task, get the message from the db
query_id = task[1]["query_id"]
- workflow = json.loads(task[1]["workflow"])
response_id = task[1]["response_id"]
message = await get_message(query_id, logger)
parameters = message.get("parameters") or {}
@@ -55,9 +54,7 @@ async def shadowfax(task, logger: logging.Logger):
# TODO: silently only grabbing the first id
pinned_node_ids.append(node["ids"][0])
if len(set(pinned_node_ids)) != 2:
- logger.error("Pathfinder queries require two pinned nodes.")
- # TODO: Update to wrap up task
- return message, 500
+ raise Exception("Pathfinder queries require two pinned nodes.")
intermediate_categories = []
path_key = next(iter(qgraph["paths"].keys()))
@@ -66,17 +63,15 @@ async def shadowfax(task, logger: logging.Logger):
constraints = qpath["constraints"]
# TODO: need to wrap up tasks
if len(constraints) > 1:
- logger.error("Pathfinder queries do not support multiple constraints.")
- return message, 500
+ raise Exception("Pathfinder queries do not support multiple constraints.")
if len(constraints) > 0:
intermediate_categories = (
constraints[0].get("intermediate_categories", None) or []
)
if len(intermediate_categories) > 1:
- logger.error(
+ raise Exception(
"Pathfinder queries do not support multiple intermediate categories"
)
- return message, 500
else:
intermediate_categories = ["biolink:NamedThing"]
@@ -184,9 +179,9 @@ async def shadowfax(task, logger: logging.Logger):
callback_id = str(uuid.uuid4())[:8]
# Put callback UID and query ID in postgres
await add_callback_id(query_id, callback_id, logger)
- logger.debug("""Sending pathfinder lookup query to gandalf.""")
await save_message(callback_id, threehop, logger)
+ logger.debug("""Sending pathfinder lookup query to gandalf.""")
await add_task(
"gandalf",
@@ -206,44 +201,65 @@ async def shadowfax(task, logger: logging.Logger):
MAX_QUERY_TIME = message["parameters"]["timeout"]
start_time = time.time()
running_callback_ids = [""]
- while time.time() - start_time < MAX_QUERY_TIME:
- # see if there are existing lookups going
- running_callback_ids = await get_running_callbacks(query_id, logger)
- # logger.info(f"Got back {len(running_callback_ids)} running lookups")
- # if there are, continue to wait
- if len(running_callback_ids) > 0:
- await asyncio.sleep(1)
- continue
- # if there aren't, lookup is complete and we need to pass on to next workflow operation
- if len(running_callback_ids) == 0:
- logger.debug("Got all lookups back. Continuing...")
- break
-
- if time.time() - start_time > MAX_QUERY_TIME:
- logger.warning(
- f"Timed out getting lookup callbacks. {len(running_callback_ids)} queries were still running..."
- )
- logger.warning(f"Running callbacks: {running_callback_ids}")
- await cleanup_callbacks(query_id, logger)
-
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Task took {time.time() - start}")
-
-
-async def process_task(task, parent_ctx, logger, limiter):
+ try:
+ while time.time() - start_time < MAX_QUERY_TIME:
+ # see if there are existing lookups going
+ running_callback_ids = await get_running_callbacks(query_id, logger)
+ # logger.info(f"Got back {len(running_callback_ids)} running lookups")
+ # if there are, continue to wait
+ if len(running_callback_ids) > 0:
+ await asyncio.sleep(1)
+ continue
+ # if there aren't, lookup is complete and we need to pass on to next workflow operation
+ if len(running_callback_ids) == 0:
+ logger.debug("Got all lookups back. Continuing...")
+ break
+
+ if time.time() - start_time > MAX_QUERY_TIME:
+ logger.warning(
+ f"Timed out getting lookup callbacks. {len(running_callback_ids)} queries were still running..."
+ )
+ logger.warning(f"Running callbacks: {running_callback_ids}")
+ await cleanup_callbacks(query_id, logger)
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]}: Cancelled while waiting for callbacks")
+
+
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await shadowfax(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/aragorn_score/worker.py b/workers/aragorn_score/worker.py
index dbe492a..db28de0 100644
--- a/workers/aragorn_score/worker.py
+++ b/workers/aragorn_score/worker.py
@@ -16,7 +16,7 @@
from shepherd_utils.db import get_message, save_message
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "aragorn.score"
@@ -1218,38 +1218,58 @@ def aragorn_score(in_message, logger: logging.Logger):
async def poll_for_tasks():
+ """On initialization, poll indefinitely for available tasks."""
loop = asyncio.get_running_loop()
cpu_count = os.cpu_count()
cpu_count = cpu_count if cpu_count is not None else 1
cpu_count = min(cpu_count, TASK_LIMIT)
executor = ProcessPoolExecutor(max_workers=cpu_count)
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- span = tracer.start_span(STREAM, context=parent_ctx)
- start = time.time()
- # given a task, get the message from the db
- response_id = task[1]["response_id"]
- workflow = json.loads(task[1]["workflow"])
- message = await get_message(response_id, logger)
- if message is not None:
- scored_message = await loop.run_in_executor(
- executor,
- aragorn_score,
- message,
- logger,
- )
- if scored_message is None:
- logger.error("Failed to score message. Returning unscored.")
- scored_message = message
- await save_message(response_id, scored_message, logger)
- else:
- logger.error(f"Failed to get {response_id} for scoring.")
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
- span.end()
- limiter.release()
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ start = time.time()
+ span = tracer.start_span(STREAM, context=parent_ctx)
+ try:
+ # given a task, get the message from the db
+ response_id = task[1]["response_id"]
+ message = await get_message(response_id, logger)
+ if message is not None:
+ scored_message = await loop.run_in_executor(
+ executor,
+ aragorn_score,
+ message,
+ logger,
+ )
+ if scored_message is None:
+ logger.error("Failed to score message. Returning unscored.")
+ scored_message = message
+ await save_message(response_id, scored_message, logger)
+ else:
+ logger.error(f"Failed to get {response_id} for scoring.")
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(
+ f"Task {task[0]} failed with unhandled error: {e}",
+ exc_info=True,
+ )
+ await handle_task_failure(STREAM, GROUP, task, logger)
+ finally:
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
+ span.end()
+ limiter.release()
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/arax/worker.py b/workers/arax/worker.py
index cb34435..53909db 100644
--- a/workers/arax/worker.py
+++ b/workers/arax/worker.py
@@ -1,13 +1,14 @@
"""ARAX entry module."""
import asyncio
+import json
import logging
import requests
import time
import uuid
from shepherd_utils.config import settings
from shepherd_utils.db import get_message, save_message
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from shepherd_utils.otel import setup_tracer
from inject_shepherd_arax_provenance import add_shepherd_arax_to_edge_sources
@@ -21,7 +22,6 @@
async def arax(task, logger: logging.Logger):
try:
- start = time.time()
query_id = task[1]["query_id"]
logger.info(f"Getting message from db for query id {query_id}")
message = await get_message(query_id, logger)
@@ -43,29 +43,44 @@ async def arax(task, logger: logging.Logger):
await save_message(response_id, result, logger)
- workflow = [{"id": "arax"}]
+ task[1]["workflow"] = json.dumps([{"id": "arax"}])
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
-
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await arax(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
except Exception as e:
- logger.error(f"Something went wrong: {e}")
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/arax_rank/worker.py b/workers/arax_rank/worker.py
index 43d9f8f..783e7a4 100644
--- a/workers/arax_rank/worker.py
+++ b/workers/arax_rank/worker.py
@@ -20,7 +20,7 @@
from shepherd_utils.db import get_message, save_message
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from ranker import arax_rank
@@ -92,40 +92,60 @@ async def poll_for_tasks() -> None:
cpu_count = min(cpu_count, TASK_LIMIT)
executor = ProcessPoolExecutor(max_workers=cpu_count)
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- span = tracer.start_span(STREAM, context=parent_ctx)
- start = time.time()
-
- # Get task details
- response_id = task[1]["response_id"]
- workflow = json.loads(task[1]["workflow"])
-
- # Get message from Redis
- message = await get_message(response_id, logger)
-
- if message is not None:
- # Run ranking in process pool for CPU-intensive operations
- ranked_message = await loop.run_in_executor(
- executor,
- rank_message,
- message,
- logger,
- )
- if ranked_message is None:
- logger.error("Ranking returned None. Returning original message.")
- ranked_message = message
- await save_message(response_id, ranked_message, logger)
- else:
- logger.error(f"Failed to get {response_id} for ranking.")
-
- # Pass to next operation in workflow
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-
- logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
- span.end()
- limiter.release()
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ start = time.time()
+ span = tracer.start_span(STREAM, context=parent_ctx)
+
+ try:
+ # Get task details
+ response_id = task[1]["response_id"]
+
+ # Get message from Redis
+ message = await get_message(response_id, logger)
+
+ if message is not None:
+ # Run ranking in process pool for CPU-intensive operations
+ ranked_message = await loop.run_in_executor(
+ executor,
+ rank_message,
+ message,
+ logger,
+ )
+ if ranked_message is None:
+ logger.error(
+ "Ranking returned None. Returning original message."
+ )
+ ranked_message = message
+ await save_message(response_id, ranked_message, logger)
+ else:
+ logger.error(f"Failed to get {response_id} for ranking.")
+
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(
+ f"Task {task[0]} failed with unhandled error: {e}",
+ exc_info=True,
+ )
+ await handle_task_failure(STREAM, GROUP, task, logger)
+ finally:
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
+ span.end()
+ limiter.release()
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/bte/worker.py b/workers/bte/worker.py
index b5b2220..5d8e890 100644
--- a/workers/bte/worker.py
+++ b/workers/bte/worker.py
@@ -8,7 +8,7 @@
from shepherd_utils.db import get_message
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "bte"
@@ -31,7 +31,7 @@ def examine_query(message):
# this can still fail if the input looks like e.g.:
# "query_graph": None
qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
- except:
+    except AttributeError:
qedges = {}
n_infer_edges = 0
for edge_id in qedges:
@@ -39,9 +39,9 @@ def examine_query(message):
n_infer_edges += 1
pathfinder = n_infer_edges == 3
if n_infer_edges > 1 and n_infer_edges and not pathfinder:
- raise Exception("Only a single infer edge is supported", 400)
+ raise Exception("Only a single infer edge is supported")
if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
- raise Exception("Mixed infer and lookup queries not supported", 400)
+ raise Exception("Mixed infer and lookup queries not supported")
infer = n_infer_edges == 1
if not infer:
return infer, None, None, pathfinder
@@ -54,27 +54,21 @@ def examine_query(message):
else:
question_node = qnode_id
if answer_node is None:
- raise Exception("Both nodes of creative edge pinned", 400)
+ raise Exception("Both nodes of creative edge pinned")
if question_node is None:
- raise Exception("No nodes of creative edge pinned", 400)
+ raise Exception("No nodes of creative edge pinned")
return infer, question_node, answer_node, pathfinder
async def bte(task, logger: logging.Logger):
"""Main BTE entrypoint that establishes query workflow."""
- start = time.time()
# given a task, get the message from the db
query_id = task[1]["query_id"]
workflow = json.loads(task[1]["workflow"])
message = await get_message(query_id, logger)
- try:
- infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
- except Exception as e:
- logger.error(e)
- return None, 500
+ infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
if pathfinder:
- # BTE doesn't currently handle Pathfinder queries
- return None, 500
+ raise Exception("BTE does not support Pathfinder type queries.")
if workflow is None:
if infer:
@@ -96,24 +90,44 @@ async def bte(task, logger: logging.Logger):
{"id": "filter_kgraph_orphans"},
]
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Task took {time.time() - start}")
+ task[1]["workflow"] = json.dumps(workflow)
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await bte(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/bte_lookup/worker.py b/workers/bte_lookup/worker.py
index 7cceab6..5ca5397 100644
--- a/workers/bte_lookup/worker.py
+++ b/workers/bte_lookup/worker.py
@@ -23,7 +23,7 @@
save_message,
)
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "bte.lookup"
@@ -46,7 +46,7 @@ def examine_query(message):
# this can still fail if the input looks like e.g.:
# "query_graph": None
qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
- except:
+    except AttributeError:
qedges = {}
n_infer_edges = 0
for edge_id in qedges:
@@ -54,9 +54,9 @@ def examine_query(message):
n_infer_edges += 1
pathfinder = n_infer_edges == 3
if n_infer_edges > 1 and n_infer_edges and not pathfinder:
- raise Exception("Only a single infer edge is supported", 400)
+ raise Exception("Only a single infer edge is supported")
if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
- raise Exception("Mixed infer and lookup queries not supported", 400)
+ raise Exception("Mixed infer and lookup queries not supported")
infer = n_infer_edges == 1
if not infer:
return infer, None, None, pathfinder
@@ -69,9 +69,9 @@ def examine_query(message):
else:
question_node = qnode_id
if answer_node is None:
- raise Exception("Both nodes of creative edge pinned", 400)
+ raise Exception("Both nodes of creative edge pinned")
if question_node is None:
- raise Exception("No nodes of creative edge pinned", 400)
+ raise Exception("No nodes of creative edge pinned")
return infer, question_node, answer_node, pathfinder
@@ -109,10 +109,8 @@ async def run_async_lookup(
async def bte_lookup(task, logger: logging.Logger):
- start = time.time()
# given a task, get the message from the db
query_id = task[1]["query_id"]
- workflow = json.loads(task[1]["workflow"])
message = await get_message(query_id, logger)
parameters = message.get("parameters") or {}
parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
@@ -126,14 +124,10 @@ async def bte_lookup(task, logger: logging.Logger):
url=settings.server_url,
)
)
- try:
- infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
- except Exception as e:
- logger.error(e)
- return None, 500
+ infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
if pathfinder:
# BTE currently doesn't handle Pathfinder queries
- return None, 500
+ raise Exception("BTE does not support Pathfinder type queries.")
if not infer:
# Put callback UID and query ID in postgres
@@ -192,9 +186,13 @@ async def bte_lookup(task, logger: logging.Logger):
start_time = time.time()
running_callback_ids = [""]
while time.time() - start_time < MAX_QUERY_TIME:
- # see if there are existing lookups going
- running_callback_ids = await get_running_callbacks(query_id, logger)
- # logger.info(f"Got back {len(running_callback_ids)} running lookups")
+ try:
+ # see if there are existing lookups going
+ running_callback_ids = await get_running_callbacks(query_id, logger)
+ except Exception:
+ # Brief backoff then retry the check rather than giving up
+ await asyncio.sleep(5)
+ continue
# if there are, continue to wait
if len(running_callback_ids) > 0:
await asyncio.sleep(1)
@@ -210,9 +208,6 @@ async def bte_lookup(task, logger: logging.Logger):
)
await cleanup_callbacks(query_id, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
class TemplateGroup(BaseModel):
"""A group of templates to be matched by given criteria."""
@@ -422,19 +417,40 @@ def expand_bte_query(query_dict: dict[str, Any], logger: logging.Logger) -> list
async def process_task(task, parent_ctx, logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await bte_lookup(task, logger)
+ try:
+ # Always wrap up the task to ACK it in the broker
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled.")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Task took {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/example_ara/worker.py b/workers/example_ara/worker.py
index b49be17..03afec0 100644
--- a/workers/example_ara/worker.py
+++ b/workers/example_ara/worker.py
@@ -6,7 +6,8 @@
+import json
import time
import uuid
from shepherd_utils.db import get_message
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from shepherd_utils.otel import setup_tracer
# Queue name
@@ -18,46 +18,58 @@
async def example_ara(task, logger: logging.Logger):
- try:
- start = time.time()
- # given a task, get the message from the db
- logger.info("Getting message from db")
- message = await get_message(task[1]["query_id"], logger)
- # logger.info(message)
- logger.info(task)
-
- workflow = [
- {"id": "example.lookup"},
- {"id": "example.score"},
- {"id": "sort_results_score"},
- {"id": "filter_results_top_n"},
- {"id": "filter_kgraph_orphans"},
- ]
- except Exception as e:
- logger.error(f"Something bad happened! {e}")
- # TODO: gracefully handle worker errors
+ # given a task, get the message from the db
+ logger.info("Getting message from db")
+ message = await get_message(task[1]["query_id"], logger)
+ # logger.info(message)
+ logger.info(task)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
+ workflow = [
+ {"id": "example.lookup"},
+ {"id": "example.score"},
+ {"id": "sort_results_score"},
+ {"id": "filter_results_top_n"},
+ {"id": "filter_kgraph_orphans"},
+ ]
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
+ task[1]["workflow"] = workflow
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await example_ara(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
except Exception as e:
- logger.error(f"Something went wrong: {e}")
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/example_lookup/worker.py b/workers/example_lookup/worker.py
index 735c2e5..a0832e7 100644
--- a/workers/example_lookup/worker.py
+++ b/workers/example_lookup/worker.py
@@ -9,14 +9,16 @@
import httpx
+from shepherd_utils.config import settings
from shepherd_utils.db import (
add_callback_id,
+ cleanup_callbacks,
get_message,
get_running_callbacks,
save_message,
)
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "example.lookup"
@@ -31,15 +33,11 @@ async def example_lookup(task, logger: logging.Logger):
Just sends a test response back to the server callback endpoint.
"""
- start = time.time()
# given a task, get the message from the db
query_id = task[1]["query_id"]
- workflow = json.loads(task[1]["workflow"])
- try:
- message = await get_message(query_id, logger)
- except Exception as e:
- logger.error(f"Task {task[0]}: Failed to get message for query {query_id}: {e}")
- raise
+ message = await get_message(query_id, logger)
+ parameters = message.get("parameters") or {}
+ parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
# Do query expansion or whatever lookup process
# We're going to stub a response
@@ -93,57 +91,56 @@ async def example_lookup(task, logger: logging.Logger):
# this worker might have a timeout set for if the lookups don't finish within a
# certain amount of time
- MAX_QUERY_TIME = 300
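+    # Honor the query's configured timeout rather than a hard-coded limit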
+    MAX_QUERY_TIME = parameters["timeout"]
start_time = time.time()
- try:
- while time.time() - start_time < MAX_QUERY_TIME:
- try:
- # see if there are existing lookups going
- running_callback_ids = await get_running_callbacks(query_id, logger)
- except Exception as e:
- logger.error(f"Task {task[0]}: Failed to check running callbacks: {e}")
- # Brief backoff then retry the check rather than giving up
- await asyncio.sleep(5)
- continue
- # logger.info(f"Got back {len(running_callback_ids)} running lookups")
- # if there aren't, lookup is complete and we need to pass on to next
- # workflow operation
- if len(running_callback_ids) == 0:
- break
-
- await asyncio.sleep(1)
- except asyncio.CancelledError:
- logger.warning(f"Task {task[0]}: Cancelled while waiting for callbacks.")
-
- # Always wrap up the task to ACK it in the broker
- try:
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- except Exception as e:
- logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
- raise
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
+ running_callback_ids = [""]
+ while time.time() - start_time < MAX_QUERY_TIME:
+ try:
+ # see if there are existing lookups going
+ running_callback_ids = await get_running_callbacks(query_id, logger)
+ except Exception:
+ # Brief backoff then retry the check rather than giving up
+ await asyncio.sleep(5)
+ continue
+ # if there aren't, lookup is complete and we need to pass on to next
+ # workflow operation
+ if len(running_callback_ids) == 0:
+ break
+
+ await asyncio.sleep(1)
+
+ if time.time() - start_time > MAX_QUERY_TIME:
+ logger.warning(
+ f"Timed out getting lookup callbacks. {len(running_callback_ids)} queries were still running..."
+ )
+ # logger.warning(f"Running callbacks: {running_callback_ids}")
+ await cleanup_callbacks(query_id, logger)
async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await example_lookup(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
except asyncio.CancelledError:
logger.warning(f"Task {task[0]} was cancelled")
except Exception as e:
logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
- # Attempt to ACK the task so it doesn't get redelivered forever
- try:
- workflow = json.loads(task[1])["workflow"]
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- except Exception as wrap_err:
- logger.error(f"Task {task[0]}: Also failed wrap up after error: {wrap_err}")
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
+ """On initialization, poll indefinitely for available tasks."""
while True:
try:
async for task, parent_ctx, logger, limiter in get_tasks(
diff --git a/workers/example_score/worker.py b/workers/example_score/worker.py
index ee9f7bb..08ba9fb 100644
--- a/workers/example_score/worker.py
+++ b/workers/example_score/worker.py
@@ -7,7 +7,7 @@
import time
import uuid
from shepherd_utils.db import get_message, save_message
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from shepherd_utils.otel import setup_tracer
# Queue name
@@ -19,36 +19,54 @@
async def example_score(task, logger: logging.Logger):
- start = time.time()
+ """Do a very random score."""
# given a task, get the message from the db
response_id = task[1]["response_id"]
- workflow = json.loads(task[1]["workflow"])
message = await get_message(response_id, logger)
# give a random score to all results
- for result in message["message"]["results"]:
+ for result in message["message"].get("results", []):
for analysis in result["analyses"]:
analysis["score"] = random.random()
await save_message(response_id, message, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await example_score(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/filter_analyses_top_n/worker.py b/workers/filter_analyses_top_n/worker.py
index 92e0ed7..9015f40 100644
--- a/workers/filter_analyses_top_n/worker.py
+++ b/workers/filter_analyses_top_n/worker.py
@@ -6,7 +6,7 @@
import time
import uuid
from shepherd_utils.db import get_message, save_message, get_query_state
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from shepherd_utils.otel import setup_tracer
# Queue name
@@ -18,7 +18,7 @@
async def filter_analyses_top_n(task, logger: logging.Logger):
- start = time.time()
+ """Filter results analyses to top n."""
# given a task, get the message from the db
response_id = task[1]["response_id"]
workflow = json.loads(task[1]["workflow"])
@@ -29,36 +29,49 @@ async def filter_analyses_top_n(task, logger: logging.Logger):
logger.error(f"Unable to find operation {STREAM} in workflow")
raise Exception(f"Operation {STREAM} is not in workflow")
n = current_op.get("max_analyses", 1000)
- try:
- for ind, result in enumerate(message["message"]["results"]):
- message["message"]["results"][ind]["analyses"] = result["analyses"][:n]
- except KeyError as e:
- # can't find the right structure of message
- logger.error(f"Error filtering results: {e}")
- return message, 400
+ for ind, result in enumerate(results):
+ message["message"]["results"][ind]["analyses"] = result["analyses"][:n]
logger.info("Returning filtered results.")
# save merged message back to db
await save_message(response_id, message, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await filter_analyses_top_n(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/filter_kgraph_orphans/worker.py b/workers/filter_kgraph_orphans/worker.py
index 3c1140b..ead4723 100644
--- a/workers/filter_kgraph_orphans/worker.py
+++ b/workers/filter_kgraph_orphans/worker.py
@@ -11,6 +11,7 @@
from shepherd_utils.shared import (
filter_kgraph_orphans,
get_tasks,
+ handle_task_failure,
wrap_up_task,
)
@@ -27,34 +28,50 @@ async def do_filter_kgraph_orphans(task, logger: logging.Logger):
Given a TRAPI message, remove all kgraph nodes and edges that aren't referenced
in any results.
"""
- start = time.time()
# given a task, get the message from the db
response_id = task[1]["response_id"]
- workflow = json.loads(task[1]["workflow"])
message = await get_message(response_id, logger)
filter_kgraph_orphans(message, logger)
# save merged message back to db
await save_message(response_id, message, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await do_filter_kgraph_orphans(task, logger)
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ """On initialization, poll indefinitely for available tasks."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+ except asyncio.CancelledError:
+ logging.info("Poll loop cancelled, shutting down.")
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
diff --git a/workers/filter_results_top_n/worker.py b/workers/filter_results_top_n/worker.py
index bb83378..b37df9e 100644
--- a/workers/filter_results_top_n/worker.py
+++ b/workers/filter_results_top_n/worker.py
@@ -6,7 +6,7 @@
import time
import uuid
from shepherd_utils.db import get_message, save_message, get_query_state
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from shepherd_utils.otel import setup_tracer
# Queue name
@@ -18,7 +18,6 @@
async def filter_results_top_n(task, logger: logging.Logger):
- start = time.time()
# given a task, get the message from the db
response_id = task[1]["response_id"]
workflow = json.loads(task[1]["workflow"])
@@ -29,35 +28,49 @@ async def filter_results_top_n(task, logger: logging.Logger):
logger.error(f"Unable to find operation {STREAM} in workflow")
raise Exception(f"Operation {STREAM} is not in workflow")
n = current_op.get("max_results", 500)
- try:
- message["message"]["results"] = results[:n]
- except KeyError as e:
- # can't find the right structure of message
- logger.error(f"Error filtering results: {e}")
- return message, 400
+
+ message["message"]["results"] = results[:n]
logger.info("Returning filtered results.")
# save merged message back to db
await save_message(response_id, message, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await filter_results_top_n(task, logger)
+        # on success, wrap up the task to ACK it in the broker
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """Poll indefinitely for available tasks, backing off on errors."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
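
With the failure handling moved out to process_task, the operation itself is a workflow lookup plus a slice. A sketch assuming the same op shape ({"id": "filter_results_top_n", "max_results": ...}) used above:

    def filter_top_n_sketch(message: dict, workflow: list, default_n: int = 500) -> None:
        """Truncate results to the top n, in place."""
        current_op = next(
            (op for op in workflow if op["id"] == "filter_results_top_n"), None
        )
        if current_op is None:
            raise Exception("Operation filter_results_top_n is not in workflow")
        n = current_op.get("max_results", default_n)
        # results are assumed to already be sorted by a prior workflow op
        message["message"]["results"] = message["message"]["results"][:n]
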
diff --git a/workers/finish_query/worker.py b/workers/finish_query/worker.py
index 842d94c..d03ceda 100644
--- a/workers/finish_query/worker.py
+++ b/workers/finish_query/worker.py
@@ -23,13 +23,16 @@
CONSUMER = str(uuid.uuid4())[:8]
TASK_LIMIT = 100
tracer = setup_tracer(STREAM)
+CALLBACK_RETRIES = 3
async def finish_query(task, logger: logging.Logger):
+ """Do all the wrap up necessary for a query."""
start = time.time()
# given a task, get the message from the db
query_id = task[1]["query_id"]
response_id = task[1]["response_id"]
+ status = task[1].get("status", "OK")
query_state = await get_query_state(query_id, logger)
if query_state is None:
@@ -41,37 +44,59 @@ async def finish_query(task, logger: logging.Logger):
message = await get_message(response_id, logger)
logs = await get_logs(response_id, logger)
message["logs"] = logs
- try:
- async with httpx.AsyncClient(timeout=60) as client:
- response = await client.post(
- callback_url,
- json=message,
- )
- response.raise_for_status()
- logger.info(f"Sent response back to {callback_url}")
- except Exception as e:
- logger.error(f"Failed to send callback to {callback_url}: {e}")
+ for attempt in range(CALLBACK_RETRIES):
+ try:
+ async with httpx.AsyncClient(timeout=120) as client:
+ response = await client.post(
+ callback_url,
+ json=message,
+ )
+ response.raise_for_status()
+ logger.info(f"Sent response back to {callback_url}")
+ break
+ except Exception as e:
+ logger.error(f"Failed to send callback to {callback_url}: {e}")
+            if attempt < CALLBACK_RETRIES - 1:
+                await asyncio.sleep(1 * (2**attempt))
- await set_query_completed(query_id, "OK", logger)
+ await set_query_completed(query_id, status, logger)
- await mark_task_as_complete(STREAM, GROUP, task[0], logger)
logger.info(f"Finished task {task[0]} in {time.time() - start}")
-async def process_task(task, parent_ctx, logger, limiter):
+async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await finish_query(task, logger)
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
finally:
+ # Always wrap up the task to ACK it in the broker
+ try:
+ await mark_task_as_complete(STREAM, GROUP, task[0], logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
span.end()
limiter.release()
+ logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """Poll indefinitely for available tasks, backing off on errors."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
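
The retry policy added to finish_query is plain exponential backoff around the callback POST. A standalone sketch assuming only httpx, with CALLBACK_RETRIES and the 120-second timeout mirroring the values above (note the sleep is skipped after the final attempt):

    import asyncio

    import httpx

    CALLBACK_RETRIES = 3


    async def post_with_backoff(callback_url: str, message: dict) -> bool:
        """POST with exponential backoff; True on success, False if retries exhaust."""
        for attempt in range(CALLBACK_RETRIES):
            try:
                async with httpx.AsyncClient(timeout=120) as client:
                    response = await client.post(callback_url, json=message)
                    response.raise_for_status()
                return True
            except Exception:
                # 1s, 2s, ... between attempts; no sleep after the last one
                if attempt < CALLBACK_RETRIES - 1:
                    await asyncio.sleep(1 * (2**attempt))
        return False
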
diff --git a/workers/gandalf/requirements.txt b/workers/gandalf/requirements.txt
index 227b37b..64cd655 100644
--- a/workers/gandalf/requirements.txt
+++ b/workers/gandalf/requirements.txt
@@ -1 +1,2 @@
-gandalf-csr>=0.1.8
+gandalf-csr>=0.1.11
+
diff --git a/workers/gandalf/worker.py b/workers/gandalf/worker.py
index a8c539d..a08c4c9 100644
--- a/workers/gandalf/worker.py
+++ b/workers/gandalf/worker.py
@@ -82,75 +82,92 @@ async def poll_for_tasks(graph: CSRGraph, bmt: Toolkit):
loop = asyncio.get_running_loop()
executor = ThreadPoolExecutor(max_workers=1)
- async for task, parent_ctx, task_logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- span = tracer.start_span(STREAM, context=parent_ctx)
- start = time.time()
+ while True:
try:
- task_logger.info("Got task for Gandalf")
- response_id = task[1]["response_id"]
- callback_id = task[1]["callback_id"]
- target = task[1].get("target", "unknown")
-
- task_logger.info("Getting message")
- message = await get_message(callback_id, task_logger)
- if message is None:
- task_logger.error(f"Failed to get {response_id} for lookup.")
- continue
-
- query_id = await get_callback_query_id(callback_id, task_logger)
- task_logger.info(f"Got original query id: {query_id}")
- if query_id is None:
- task_logger.error("Failed to get original query id.")
- continue
-
- query_state = await get_query_state(query_id, task_logger)
- if query_state is None:
- task_logger.error("Failed to get query state.")
- continue
-
- response_id = query_state[7]
-
- lookup_response = await loop.run_in_executor(
- executor,
- gandalf_lookup,
- graph,
- bmt,
- message,
- task_logger,
- )
-
- if DEBUG_RESPONSES and len(lookup_response["message"]["results"]) > 0:
- debug_dir = Path("debug")
- debug_dir.mkdir(exist_ok=True)
- debug_path = debug_dir / f"{query_id}_{callback_id}_response.json"
- with open(debug_path, "w", encoding="utf-8") as f:
- json.dump(lookup_response, f, indent=2)
-
- task_logger.info(f"Saving callback {callback_id} to redis")
- await save_message(callback_id, lookup_response, task_logger)
- task_logger.info(f"Saved callback {callback_id} to redis")
-
- await add_task(
- "merge_message",
- {
- "target": target,
- "query_id": query_id,
- "response_id": response_id,
- "callback_id": callback_id,
- "log_level": task[1].get("log_level", 20),
- "otel": task[1]["otel"],
- },
- task_logger,
- )
- except Exception:
- task_logger.exception(f"Task {task[0]} failed")
- finally:
- await mark_task_as_complete(STREAM, GROUP, task[0], logger)
- task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
- span.end()
- limiter.release()
+ async for task, parent_ctx, task_logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ span = tracer.start_span(STREAM, context=parent_ctx)
+ start = time.time()
+ try:
+ task_logger.info("Got task for Gandalf")
+ response_id = task[1]["response_id"]
+ callback_id = task[1]["callback_id"]
+ target = task[1].get("target", "unknown")
+
+ task_logger.info("Getting message")
+ message = await get_message(callback_id, task_logger)
+ if message is None:
+ task_logger.error(f"Failed to get {response_id} for lookup.")
+ continue
+
+ query_id = await get_callback_query_id(callback_id, task_logger)
+ task_logger.info(f"Got original query id: {query_id}")
+ if query_id is None:
+ task_logger.error("Failed to get original query id.")
+ continue
+
+ query_state = await get_query_state(query_id, task_logger)
+ if query_state is None:
+ task_logger.error("Failed to get query state.")
+ continue
+
+ response_id = query_state[7]
+
+ lookup_response = await loop.run_in_executor(
+ executor,
+ gandalf_lookup,
+ graph,
+ bmt,
+ message,
+ task_logger,
+ )
+
+ if (
+ DEBUG_RESPONSES
+ and len(lookup_response["message"]["results"]) > 0
+ ):
+ debug_dir = Path("debug")
+ debug_dir.mkdir(exist_ok=True)
+ debug_path = (
+ debug_dir / f"{query_id}_{callback_id}_response.json"
+ )
+ with open(debug_path, "w", encoding="utf-8") as f:
+ json.dump(lookup_response, f, indent=2)
+
+ task_logger.info(f"Saving callback {callback_id} to redis")
+ await save_message(callback_id, lookup_response, task_logger)
+ task_logger.info(f"Saved callback {callback_id} to redis")
+
+ await add_task(
+ "merge_message",
+ {
+ "target": target,
+ "query_id": query_id,
+ "response_id": response_id,
+ "callback_id": callback_id,
+ "log_level": task[1].get("log_level", 20),
+ "otel": task[1]["otel"],
+ },
+ task_logger,
+ )
+ except Exception:
+ task_logger.exception(f"Task {task[0]} failed")
+ finally:
+ try:
+                        await mark_task_as_complete(STREAM, GROUP, task[0], task_logger)
+                    except Exception as e:
+                        task_logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ task_logger.info(
+ f"Finished task {task[0]} in {time.time() - start:.2f}s"
+ )
+ span.end()
+ limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
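
Gandalf keeps its CPU-bound lookup off the event loop by running it in a single-worker ThreadPoolExecutor. A minimal sketch of that pattern; gandalf_lookup_sketch is a stand-in for the real blocking CSR-graph call:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    # a single worker serializes lookups against the shared in-memory graph
    executor = ThreadPoolExecutor(max_workers=1)


    def gandalf_lookup_sketch(message: dict) -> dict:
        """Stand-in for the blocking CSR-graph lookup."""
        return {"message": {"results": []}}


    async def run_lookup(message: dict) -> dict:
        loop = asyncio.get_running_loop()
        # offload the CPU-bound call so the loop keeps servicing other tasks
        return await loop.run_in_executor(executor, gandalf_lookup_sketch, message)
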
diff --git a/workers/gandalf_rehydrate/requirements.txt b/workers/gandalf_rehydrate/requirements.txt
index 227b37b..64cd655 100644
--- a/workers/gandalf_rehydrate/requirements.txt
+++ b/workers/gandalf_rehydrate/requirements.txt
@@ -1 +1,2 @@
-gandalf-csr>=0.1.8
+gandalf-csr>=0.1.11
+
diff --git a/workers/gandalf_rehydrate/worker.py b/workers/gandalf_rehydrate/worker.py
index c51b985..dbffeae 100644
--- a/workers/gandalf_rehydrate/worker.py
+++ b/workers/gandalf_rehydrate/worker.py
@@ -18,7 +18,7 @@
save_message,
)
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "gandalf.rehydrate"
@@ -80,47 +80,64 @@ async def poll_for_tasks(graph: CSRGraph, bmt: Toolkit):
loop = asyncio.get_running_loop()
executor = ThreadPoolExecutor(max_workers=1)
- async for task, parent_ctx, task_logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- span = tracer.start_span(STREAM, context=parent_ctx)
- start = time.time()
- task_logger.info("Got task for Gandalf Rehydration")
- workflow = json.loads(task[1]["workflow"])
- response_id = task[1]["response_id"]
+ while True:
try:
- message = await get_message(response_id, task_logger)
- if message is None:
- task_logger.error(f"Failed to get {response_id} for rehydration.")
- continue
-
- hydrated_response = await loop.run_in_executor(
- executor,
- gandalf_rehydration,
- graph,
- bmt,
- message,
- task_logger,
- )
-
- if DEBUG_RESPONSES and len(hydrated_response["message"]["results"]) > 0:
- debug_dir = Path("debug")
- debug_dir.mkdir(exist_ok=True)
- debug_path = debug_dir / f"{response_id}_response.json"
- with open(debug_path, "w", encoding="utf-8") as f:
- json.dump(hydrated_response, f, indent=2)
-
- task_logger.info(f"Saving response {response_id} to redis")
- await save_message(response_id, hydrated_response, task_logger)
- task_logger.info(f"Saved response {response_id} to redis")
-
- except Exception:
- task_logger.exception(f"Task {task[0]} failed")
- finally:
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- task_logger.info(f"Finished task {task[0]} in {time.time() - start:.2f}s")
- span.end()
- limiter.release()
+ async for task, parent_ctx, task_logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ span = tracer.start_span(STREAM, context=parent_ctx)
+ start = time.time()
+ task_logger.info("Got task for Gandalf Rehydration")
+ workflow = json.loads(task[1]["workflow"])
+ response_id = task[1]["response_id"]
+ try:
+ message = await get_message(response_id, task_logger)
+ if message is None:
+ task_logger.error(
+ f"Failed to get {response_id} for rehydration."
+ )
+ continue
+
+ hydrated_response = await loop.run_in_executor(
+ executor,
+ gandalf_rehydration,
+ graph,
+ bmt,
+ message,
+ task_logger,
+ )
+
+ if (
+ DEBUG_RESPONSES
+ and len(hydrated_response["message"]["results"]) > 0
+ ):
+ debug_dir = Path("debug")
+ debug_dir.mkdir(exist_ok=True)
+ debug_path = debug_dir / f"{response_id}_response.json"
+ with open(debug_path, "w", encoding="utf-8") as f:
+ json.dump(hydrated_response, f, indent=2)
+
+ await save_message(response_id, hydrated_response, task_logger)
+                    # on success, wrap up the task to ACK it in the broker
+                    try:
+                        await wrap_up_task(STREAM, GROUP, task, task_logger)
+                    except Exception as e:
+                        task_logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+
+ except Exception:
+ task_logger.exception(f"Task {task[0]} failed")
+                    await handle_task_failure(STREAM, GROUP, task, task_logger)
+ finally:
+ task_logger.info(
+ f"Finished task {task[0]} in {time.time() - start:.2f}s"
+ )
+ span.end()
+ limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
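
The DEBUG_RESPONSES guard used by both gandalf workers is a simple opt-in dump to disk. Sketched here with an environment variable as the assumed source of the setting (the workers may define it differently):

    import json
    import os
    from pathlib import Path

    DEBUG_RESPONSES = os.environ.get("DEBUG_RESPONSES", "false").lower() == "true"


    def maybe_dump_response(response: dict, response_id: str) -> None:
        """Write a response to debug/ when enabled and results exist."""
        if DEBUG_RESPONSES and len(response["message"]["results"]) > 0:
            debug_dir = Path("debug")
            debug_dir.mkdir(exist_ok=True)
            debug_path = debug_dir / f"{response_id}_response.json"
            with open(debug_path, "w", encoding="utf-8") as f:
                json.dump(response, f, indent=2)
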
diff --git a/workers/merge_message/worker.py b/workers/merge_message/worker.py
index 9c551b1..1ffd246 100644
--- a/workers/merge_message/worker.py
+++ b/workers/merge_message/worker.py
@@ -597,64 +597,85 @@ async def poll_for_tasks():
cpu_count = cpu_count if cpu_count is not None else 1
cpu_count = min(cpu_count, TASK_LIMIT)
executor = ProcessPoolExecutor(max_workers=cpu_count)
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, cpu_count
- ):
- span = tracer.start_span(STREAM, context=parent_ctx)
- query_id = task[1]["query_id"]
- response_id = task[1]["response_id"]
- callback_id = task[1]["callback_id"]
- target = task[1]["target"]
- got_lock = await acquire_lock(response_id, CONSUMER, logger)
- if got_lock:
- logger.info(f"[{callback_id}] Obtained lock.")
-
- # given a task, get the message from the db
- original_query = await get_message(query_id, logger)
- if original_query is None:
- logger.error(
- f"Failed to get original query for {query_id}. Discarding callback response."
- )
- await remove_lock(response_id, CONSUMER, logger)
- await remove_callback_id(callback_id, logger)
- limiter.release()
- await mark_task_as_complete(STREAM, GROUP, task[0], logger)
- span.end()
- continue
- original_query_graph = original_query["message"]["query_graph"]
- callback_response = await get_message(callback_id, logger)
- lock_time = time.time()
- original_response = await get_message(response_id, logger)
- # do message merging
- try:
- merged_message = await loop.run_in_executor(
- executor,
- merge_messages,
- target,
- original_query_graph,
- original_response,
- callback_response,
- logger,
- )
- # save merged message back to db
- await save_message(response_id, merged_message, logger)
- except Exception:
- logger.error(
- f"[{callback_id}] Error merging message: {traceback.format_exc()}"
- )
- logger.info(
- f"[{callback_id}] Kept the lock for {time.time() - lock_time} seconds"
- )
- # remove lock so others can now modify message
- await remove_lock(response_id, CONSUMER, logger)
- else:
- logger.error(
- f"Failed to obtain lock for {query_id}. Discarding callback response."
- )
- await remove_callback_id(callback_id, logger)
- limiter.release()
- await mark_task_as_complete(STREAM, GROUP, task[0], logger)
- span.end()
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, cpu_count
+ ):
+ start = time.time()
+ span = tracer.start_span(STREAM, context=parent_ctx)
+ try:
+ query_id = task[1]["query_id"]
+ response_id = task[1]["response_id"]
+ callback_id = task[1]["callback_id"]
+ target = task[1]["target"]
+ got_lock = await acquire_lock(response_id, CONSUMER, logger)
+ if got_lock:
+ logger.info(f"[{callback_id}] Obtained lock.")
+
+ # given a task, get the message from the db
+ original_query = await get_message(query_id, logger)
+ if original_query is None:
+ logger.error(
+ f"Failed to get original query for {query_id}. Discarding callback response."
+ )
+                            await remove_lock(response_id, CONSUMER, logger)
+                            await remove_callback_id(callback_id, logger)
+                            # the finally block below still ACKs the task, ends
+                            # the span, and releases the limiter on this continue
+                            continue
+ original_query_graph = original_query["message"]["query_graph"]
+ callback_response = await get_message(callback_id, logger)
+ lock_time = time.time()
+ original_response = await get_message(response_id, logger)
+ # do message merging
+ try:
+ merged_message = await loop.run_in_executor(
+ executor,
+ merge_messages,
+ target,
+ original_query_graph,
+ original_response,
+ callback_response,
+ logger,
+ )
+ # save merged message back to db
+ await save_message(response_id, merged_message, logger)
+ except Exception:
+ logger.error(
+ f"[{callback_id}] Error merging message: {traceback.format_exc()}"
+ )
+ logger.info(
+ f"[{callback_id}] Kept the lock for {time.time() - lock_time} seconds"
+ )
+ # remove lock so others can now modify message
+ await remove_lock(response_id, CONSUMER, logger)
+ await remove_callback_id(callback_id, logger)
+                        else:
+                            logger.error(
+                                f"Failed to obtain lock for {query_id}. Discarding callback response."
+                            )
+                            await remove_callback_id(callback_id, logger)
+ except Exception as e:
+ logger.error(
+ f"Task {task[0]} failed with unhandled error: {e}",
+ exc_info=True,
+ )
+ finally:
+ try:
+ await mark_task_as_complete(STREAM, GROUP, task[0], logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ logger.info(
+ f"Finished task {task[0]} in {time.time() - start:.2f}s"
+ )
+ span.end()
+ limiter.release()
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
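
merge_message serializes writers per response with the Redis lock helpers. The control flow reduces to the sketch below, assuming acquire_lock/remove_lock live in shepherd_utils.db alongside the other helpers; a try/finally around the merge is one way to guarantee the lock is released even if merging raises:

    from shepherd_utils.db import acquire_lock, remove_lock


    async def merge_with_lock(response_id, consumer, logger) -> bool:
        """Merge under the per-response lock; release whatever was acquired."""
        got_lock = await acquire_lock(response_id, consumer, logger)
        if not got_lock:
            # another consumer holds the lock; the caller decides whether
            # to discard or requeue the callback response
            return False
        try:
            # ... load the original response, run merge_messages in the
            # process pool, and save the merged message back to the db ...
            pass
        finally:
            # release even if the merge itself blew up
            await remove_lock(response_id, consumer, logger)
        return True
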
diff --git a/workers/sipr/worker.py b/workers/sipr/worker.py
index c7c555e..afe8a61 100644
--- a/workers/sipr/worker.py
+++ b/workers/sipr/worker.py
@@ -1,6 +1,7 @@
"""SIPR (Set-Input Page Rank) module."""
import asyncio
+import json
import logging
import time
import uuid
@@ -15,7 +16,7 @@
save_message,
)
from shepherd_utils.otel import setup_tracer
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
# Queue name
STREAM = "sipr"
@@ -196,11 +197,6 @@ def distribute_weights(trapi_responses, target_nodes, logger):
async def sipr(task, logger: logging.Logger):
- start = time.time()
- workflow = [
- {"id": "sipr"},
- {"id": "sort_results_score"},
- ]
try:
# given a task, get the message from the db
logger.info("Getting message from db")
@@ -332,30 +328,48 @@ async def sipr(task, logger: logging.Logger):
except NotImplementedError:
logger.info("SIPR only supports Set Input Queries.")
- except Exception as e:
- logger.error(f"Something bad happened! {e}")
-
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
-
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
+ task[1]["workflow"] = json.dumps(
+ [
+ {"id": "sipr"},
+ {"id": "sort_results_score"},
+ ]
+ )
async def process_task(task, parent_ctx, logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await sipr(task, logger)
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled.")
except Exception as e:
- logger.error(f"Something went wrong: {e}")
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+        logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """Poll indefinitely for available tasks, backing off on errors."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
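
sipr no longer passes a workflow list around; it queues its follow-up operation by rewriting the serialized workflow on the task payload itself. A toy illustration of both halves of that contract, with the task tuple shape mirroring the (id, fields) pairs used throughout:

    import json

    # producer side (end of sipr): queue sort_results_score as the next op
    task = ("task-id", {"workflow": "[]", "response_id": "resp-1"})
    task[1]["workflow"] = json.dumps(
        [
            {"id": "sipr"},
            {"id": "sort_results_score"},
        ]
    )

    # whoever wraps up the task can recover the remaining operations
    workflow = json.loads(task[1]["workflow"])
    assert workflow[0]["id"] == "sipr"
    remaining = workflow[1:]  # [{"id": "sort_results_score"}]
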
diff --git a/workers/sort_results_score/worker.py b/workers/sort_results_score/worker.py
index 00ce2f4..ee96ff0 100644
--- a/workers/sort_results_score/worker.py
+++ b/workers/sort_results_score/worker.py
@@ -6,7 +6,7 @@
import time
import uuid
from shepherd_utils.db import get_message, save_message, get_query_state
-from shepherd_utils.shared import get_tasks, wrap_up_task
+from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task
from shepherd_utils.otel import setup_tracer
# Queue name
@@ -18,7 +18,6 @@
async def sort_results_score(task, logger: logging.Logger):
- start = time.time()
# given a task, get the message from the db
response_id = task[1]["response_id"]
workflow = json.loads(task[1]["workflow"])
@@ -27,53 +26,64 @@ async def sort_results_score(task, logger: logging.Logger):
current_op = workflow[0]
aord = current_op.get("ascending_or_descending", "descending")
reverse = aord == "descending"
- try:
- for ind, result in enumerate(results):
- message["message"]["results"][ind]["analyses"] = sorted(
- result["analyses"],
- key=lambda x: x.get("score", 0),
- reverse=reverse,
- )
- if reverse:
- message["message"]["results"] = sorted(
- results,
- key=lambda x: x["analyses"][0].get("score", 0),
- reverse=reverse,
- )
- else:
- message["message"]["results"] = sorted(
- results,
- key=lambda x: x["analyses"][-1].get("score", 0),
- reverse=reverse,
- )
- except KeyError as e:
- # can't find the right structure of message
- err = f"Error sorting results: {e}"
- logger.error(err)
- raise KeyError(err)
+ for ind, result in enumerate(results):
+ message["message"]["results"][ind]["analyses"] = sorted(
+ result["analyses"],
+ key=lambda x: x.get("score", 0),
+ reverse=reverse,
+ )
+ if reverse:
+ message["message"]["results"] = sorted(
+ results,
+ key=lambda x: x["analyses"][0].get("score", 0),
+ reverse=reverse,
+ )
+ else:
+ message["message"]["results"] = sorted(
+ results,
+ key=lambda x: x["analyses"][-1].get("score", 0),
+ reverse=reverse,
+ )
logger.info("Returning sorted results.")
# save merged message back to db
await save_message(response_id, message, logger)
- await wrap_up_task(STREAM, GROUP, task, workflow, logger)
- logger.info(f"Finished task {task[0]} in {time.time() - start}")
-
async def process_task(task, parent_ctx, logger, limiter):
+ """Process a given task and ACK in redis."""
+ start = time.time()
span = tracer.start_span(STREAM, context=parent_ctx)
try:
await sort_results_score(task, logger)
+ try:
+ await wrap_up_task(STREAM, GROUP, task, logger)
+ except Exception as e:
+ logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
+ except asyncio.CancelledError:
+ logger.warning(f"Task {task[0]} was cancelled.")
+ except Exception as e:
+ logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
+ await handle_task_failure(STREAM, GROUP, task, logger)
finally:
span.end()
limiter.release()
+        logger.info(f"Finished task {task[0]} in {time.time() - start}")
async def poll_for_tasks():
- async for task, parent_ctx, logger, limiter in get_tasks(
- STREAM, GROUP, CONSUMER, TASK_LIMIT
- ):
- asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+    """Poll indefinitely for available tasks, backing off on errors."""
+ while True:
+ try:
+ async for task, parent_ctx, logger, limiter in get_tasks(
+ STREAM, GROUP, CONSUMER, TASK_LIMIT
+ ):
+ asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
+        except asyncio.CancelledError:
+            logging.info("Poll loop cancelled, shutting down.")
+            break
+ except Exception as e:
+ logging.error(f"Error in task polling loop: {e}", exc_info=True)
+ await asyncio.sleep(5) # back off before retrying
if __name__ == "__main__":
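
The sort is two-level: analyses within each result, then results by their leading analysis score. A toy run of the descending branch, assuming the same TRAPI result shape used above:

    results = [
        {"analyses": [{"score": 0.2}, {"score": 0.9}]},
        {"analyses": [{"score": 0.5}]},
    ]
    reverse = True  # "descending"

    # sort each result's analyses so index 0 holds its best score
    for ind, result in enumerate(results):
        results[ind]["analyses"] = sorted(
            result["analyses"], key=lambda x: x.get("score", 0), reverse=reverse
        )
    # then order results by that leading score
    results = sorted(
        results, key=lambda x: x["analyses"][0].get("score", 0), reverse=reverse
    )
    print([r["analyses"][0]["score"] for r in results])  # [0.9, 0.5]
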