Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion shepherd_server/openapi-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ contact:
x-id: https://github.com/maximusunc
x-role: responsible developer
description: '<img src="/static/favicon.png" width="200px"><br /><br />Shepherd: Translator Autonomous Relay Agent Platform'
version: 0.6.7
version: 0.6.8
servers:
- description: Default server
url: https://shepherd.renci.org
Expand Down
17 changes: 8 additions & 9 deletions shepherd_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,14 @@ async def get_message(
"""Get the message from db."""
message = {}
start = time.time()
try:
# print(f"Putting {query_id} on {ara_target} stream")
message = await data_db_client.get(message_id)
if message is not None:
start_decomp = time.time()
message = orjson.loads(zstandard.decompress(message))
logger.debug(f"Decompression took {time.time() - start_decomp}")
except Exception as e:
message = await data_db_client.get(message_id)
if message is None:
# failed to get message from db
logger.error(f"Failed to get {message_id} from db: {e}")
raise Exception(f"Failed to get {message_id} from db: {e}")

start_decomp = time.time()
message = orjson.loads(zstandard.decompress(message))
logger.debug(f"Decompression took {time.time() - start_decomp}")
logger.debug(f"Getting message took {time.time() - start} seconds")
return message

Expand Down Expand Up @@ -341,6 +339,7 @@ async def get_running_callbacks(
continue
except Exception as e:
logger.error(f"Failed to get running lookups: {e}")
raise
return running_lookups


Expand Down
27 changes: 26 additions & 1 deletion shepherd_utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ async def wrap_up_task(
stream: str,
group: str,
task: Tuple[str, dict],
workflow: List[dict],
logger: logging.Logger,
):
"""Call the next task and mark this one as complete."""
workflow = json.loads(task[1]["workflow"])
logger.info(workflow)
# remove the operation we just did
if stream == workflow[0]["id"]:
# make sure the worker is in the workflow
Expand Down Expand Up @@ -106,6 +107,30 @@ async def wrap_up_task(
await save_logs(task[1]["response_id"], logger)


async def handle_task_failure(
    stream: str,
    group: str,
    task: Tuple[str, dict],
    logger: logging.Logger,
) -> None:
    """Handle any full query failures."""
    task_id, payload = task
    # ACK the failed task in its stream/group so it is not redelivered,
    # then persist whatever logs were collected for this response.
    await mark_task_as_complete(stream, group, task_id, logger)
    await save_logs(payload["response_id"], logger)
    logger.error("Sending task straight to finish_query.")
    # Short-circuit the rest of the pipeline: hand the query directly to
    # finish_query with an ERROR status and an empty remaining workflow.
    finish_payload = {
        "query_id": payload["query_id"],
        "response_id": payload["response_id"],
        "workflow": "[]",
        "log_level": payload.get("log_level", 20),
        "otel": payload["otel"],
        "status": "ERROR",
    }
    await add_task("finish_query", finish_payload, logger)


def recursive_get_edge_support_graphs(
edge: str,
edges: set,
Expand Down
58 changes: 37 additions & 21 deletions workers/aragorn/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from shepherd_utils.db import get_message
from shepherd_utils.otel import setup_tracer
from shepherd_utils.shared import get_tasks, wrap_up_task
from shepherd_utils.shared import get_tasks, handle_task_failure, wrap_up_task

# Queue name
STREAM = "aragorn"
Expand All @@ -31,27 +31,27 @@ def examine_query(message):
# this can still fail if the input looks like e.g.:
# "query_graph": None
qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
except:
except KeyError:
qedges = {}
try:
# this can still fail if the input looks like e.g.:
# "query_graph": None
qpaths = message.get("message", {}).get("query_graph", {}).get("paths", {})
except:
except KeyError:
qpaths = {}
if len(qpaths) > 1:
raise Exception("Only a single path is supported", 400)
raise Exception("Only a single path is supported")
if (len(qpaths) > 0) and (len(qedges) > 0):
raise Exception("Mixed mode pathfinder queries are not supported", 400)
raise Exception("Mixed mode pathfinder queries are not supported")
pathfinder = len(qpaths) == 1
n_infer_edges = 0
for edge_id in qedges:
if qedges.get(edge_id, {}).get("knowledge_type", "lookup") == "inferred":
n_infer_edges += 1
if n_infer_edges > 1 and n_infer_edges:
raise Exception("Only a single infer edge is supported", 400)
raise Exception("Only a single infer edge is supported")
if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
raise Exception("Mixed infer and lookup queries not supported", 400)
raise Exception("Mixed infer and lookup queries not supported")
infer = n_infer_edges == 1
if not infer:
return infer, None, None, pathfinder
Expand All @@ -64,23 +64,19 @@ def examine_query(message):
else:
question_node = qnode_id
if answer_node is None:
raise Exception("Both nodes of creative edge pinned", 400)
raise Exception("Both nodes of creative edge pinned")
if question_node is None:
raise Exception("No nodes of creative edge pinned", 400)
raise Exception("No nodes of creative edge pinned")
return infer, question_node, answer_node, pathfinder


async def aragorn(task, logger: logging.Logger):
start = time.time()
"""Examine and define the Aragorn workflow."""
# given a task, get the message from the db
query_id = task[1]["query_id"]
workflow = json.loads(task[1]["workflow"])
message = await get_message(query_id, logger)
try:
infer, question_qnode, answer_qnode, pathfinder = examine_query(message)
except Exception as e:
logger.error(e)
return None, 500
infer, question_qnode, answer_qnode, pathfinder = examine_query(message)

if workflow is None:
if infer:
Expand Down Expand Up @@ -113,24 +109,44 @@ async def aragorn(task, logger: logging.Logger):
{"id": "filter_kgraph_orphans"},
]

await wrap_up_task(STREAM, GROUP, task, workflow, logger)
logger.info(f"Task took {time.time() - start}")
task[1]["workflow"] = json.dumps(workflow)


async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
    """Process a given task and ACK in redis.

    Runs the aragorn workflow-planning step for `task`, then wraps the task
    up (ACK + schedule next step).  Any unhandled failure routes the query
    straight to finish_query via handle_task_failure.  The OTEL span and the
    concurrency limiter slot are always released in the finally block.
    """
    start = time.time()
    # Child span parented to the context propagated with the task.
    span = tracer.start_span(STREAM, context=parent_ctx)
    try:
        await aragorn(task, logger)
        try:
            # NOTE(review): logs the full task payload at INFO — looks like
            # debug leftover; confirm whether this should be debug/removed.
            logger.info(task)
            # Wrap-up failures are logged but intentionally not escalated to
            # handle_task_failure; the task will not be ACKed here.
            await wrap_up_task(STREAM, GROUP, task, logger)
        except Exception as e:
            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
    except asyncio.CancelledError:
        # Worker shutdown/cancellation: log only, no failure routing.
        logger.warning(f"Task {task[0]} was cancelled.")
    except Exception as e:
        # Planning failed: record the error and push the query to finish_query.
        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
        await handle_task_failure(STREAM, GROUP, task, logger)
    finally:
        span.end()
        limiter.release()
        logger.info(f"Task took {time.time() - start}")


async def poll_for_tasks():
    """On initialization, poll indefinitely for available tasks.

    Consumes tasks from the stream via get_tasks and fans each one out to
    process_task as a background asyncio task.  Unexpected errors in the
    polling loop are logged and retried after a short backoff; cancellation
    is propagated so the worker actually shuts down.
    """
    # Keep strong references to in-flight tasks: the event loop only holds a
    # weak reference, so an unreferenced task can be garbage-collected
    # mid-execution (see asyncio.create_task docs).
    background_tasks = set()
    while True:
        try:
            async for task, parent_ctx, logger, limiter in get_tasks(
                STREAM, GROUP, CONSUMER, TASK_LIMIT
            ):
                t = asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
                background_tasks.add(t)
                t.add_done_callback(background_tasks.discard)
        except asyncio.CancelledError:
            logging.info("Poll loop cancelled, shutting down.")
            # Re-raise so cancellation terminates the loop instead of
            # silently continuing the while True iteration.
            raise
        except Exception as e:
            logging.error(f"Error in task polling loop: {e}", exc_info=True)
            await asyncio.sleep(5)  # back off before retrying


if __name__ == "__main__":
Expand Down
59 changes: 40 additions & 19 deletions workers/aragorn_lookup/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
save_message,
)
from shepherd_utils.otel import setup_tracer
from shepherd_utils.shared import add_task, get_tasks, wrap_up_task
from shepherd_utils.shared import add_task, get_tasks, handle_task_failure, wrap_up_task

# Queue name
STREAM = "aragorn.lookup"
Expand All @@ -46,17 +46,17 @@ def examine_query(message):
# this can still fail if the input looks like e.g.:
# "query_graph": None
qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
except:
except KeyError:
qedges = {}
n_infer_edges = 0
for edge_id in qedges:
if qedges.get(edge_id, {}).get("knowledge_type", "lookup") == "inferred":
n_infer_edges += 1
pathfinder = n_infer_edges == 3
if n_infer_edges > 1 and n_infer_edges and not pathfinder:
raise Exception("Only a single infer edge is supported", 400)
raise Exception("Only a single infer edge is supported")
if (n_infer_edges > 0) and (n_infer_edges < len(qedges)):
raise Exception("Mixed infer and lookup queries not supported", 400)
raise Exception("Mixed infer and lookup queries not supported")
infer = n_infer_edges == 1
if not infer:
return infer, None, None, pathfinder
Expand All @@ -69,9 +69,9 @@ def examine_query(message):
else:
question_node = qnode_id
if answer_node is None:
raise Exception("Both nodes of creative edge pinned", 400)
raise Exception("Both nodes of creative edge pinned")
if question_node is None:
raise Exception("No nodes of creative edge pinned", 400)
raise Exception("No nodes of creative edge pinned")
return infer, question_node, answer_node, pathfinder


Expand Down Expand Up @@ -109,11 +109,10 @@ async def run_async_lookup(


async def aragorn_lookup(task, logger: logging.Logger):
start = time.time()
"""Do Aragorn lookup operation."""
# given a task, get the message from the db
query_id = task[1]["query_id"]
response_id = task[1]["response_id"]
workflow = json.loads(task[1]["workflow"])
message = await get_message(query_id, logger)
parameters = message.get("parameters") or {}
parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
Expand Down Expand Up @@ -241,9 +240,13 @@ async def aragorn_lookup(task, logger: logging.Logger):
start_time = time.time()
running_callback_ids = [""]
while time.time() - start_time < MAX_QUERY_TIME:
# see if there are existing lookups going
running_callback_ids = await get_running_callbacks(query_id, logger)
# logger.info(f"Got back {len(running_callback_ids)} running lookups")
try:
# see if there are existing lookups going
running_callback_ids = await get_running_callbacks(query_id, logger)
except Exception:
# Brief backoff then retry the check rather than giving up
await asyncio.sleep(5)
continue
# if there are, continue to wait
if len(running_callback_ids) > 0:
await asyncio.sleep(1)
Expand All @@ -260,9 +263,6 @@ async def aragorn_lookup(task, logger: logging.Logger):
# logger.warning(f"Running callbacks: {running_callback_ids}")
await cleanup_callbacks(query_id, logger)

await wrap_up_task(STREAM, GROUP, task, workflow, logger)
logger.info(f"Finished task {task[0]} in {time.time() - start}")


def get_infer_parameters(input_message):
"""Given an infer input message, return the parameters needed to run the infer.
Expand Down Expand Up @@ -386,20 +386,41 @@ def expand_aragorn_query(input_message, logger: logging.Logger):
return messages


async def process_task(task, parent_ctx, logger, limiter):
async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
    """Process a given task and ACK in redis.

    Runs the aragorn lookup operation for `task`, then wraps the task up
    (ACK + schedule next step).  Any unhandled failure routes the query
    straight to finish_query via handle_task_failure.  The OTEL span and the
    concurrency limiter slot are always released in the finally block.
    """
    start = time.time()
    # Child span parented to the context propagated with the task.
    span = tracer.start_span(STREAM, context=parent_ctx)
    try:
        await aragorn_lookup(task, logger)
        # Always wrap up the task to ACK it in the broker
        try:
            # Wrap-up failures are logged but intentionally not escalated to
            # handle_task_failure; the task will not be ACKed here.
            await wrap_up_task(STREAM, GROUP, task, logger)
        except Exception as e:
            logger.error(f"Task {task[0]}: Failed to wrap up task: {e}")
    except asyncio.CancelledError:
        # Worker shutdown/cancellation: log only, no failure routing.
        logger.warning(f"Task {task[0]} was cancelled")
    except Exception as e:
        # Lookup failed: record the error and push the query to finish_query.
        logger.error(f"Task {task[0]} failed with unhandled error: {e}", exc_info=True)
        await handle_task_failure(STREAM, GROUP, task, logger)
    finally:
        span.end()
        limiter.release()
        logger.info(f"Finished task {task[0]} in {time.time() - start}")


async def poll_for_tasks():
    """On initialization, poll indefinitely for available tasks.

    Consumes tasks from the stream via get_tasks and fans each one out to
    process_task as a background asyncio task.  Unexpected errors in the
    polling loop are logged and retried after a short backoff; cancellation
    is propagated so the worker actually shuts down.
    """
    # Keep strong references to in-flight tasks: the event loop only holds a
    # weak reference, so an unreferenced task can be garbage-collected
    # mid-execution (see asyncio.create_task docs).
    background_tasks = set()
    while True:
        try:
            async for task, parent_ctx, logger, limiter in get_tasks(
                STREAM, GROUP, CONSUMER, TASK_LIMIT
            ):
                t = asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
                background_tasks.add(t)
                t.add_done_callback(background_tasks.discard)
        except asyncio.CancelledError:
            logging.info("Poll loop cancelled, shutting down.")
            # Re-raise so cancellation terminates the loop instead of
            # silently continuing the while True iteration.
            raise
        except Exception as e:
            logging.error(f"Error in task polling loop: {e}", exc_info=True)
            await asyncio.sleep(5)  # back off before retrying


if __name__ == "__main__":
Expand Down
Loading
Loading