diff --git a/services/data/postgres_async_db.py b/services/data/postgres_async_db.py
index a3faf2ad..34cd740b 100644
--- a/services/data/postgres_async_db.py
+++ b/services/data/postgres_async_db.py
@@ -95,7 +95,7 @@ async def _init(self, db_conf: DBConfiguration, create_triggers=DB_TRIGGER_CREAT
 
                 break  # Break the retry loop
             except Exception as e:
-                self.logger.exception("Exception occured")
+                self.logger.exception("Exception occurred")
                 if retries - i <= 1:
                     raise e
                 time.sleep(connection_retry_wait_time_seconds)
@@ -296,13 +296,13 @@ async def create_record(self, record_dict):
             return aiopg_exception_handling(error)
 
     async def update_row(self, filter_dict={}, update_dict={}):
+        # Values are passed as named query parameters, not spliced into the SQL
+        query_params = {}
         # generate where clause
         filters = []
         for col_name, col_val in filter_dict.items():
-            v = str(col_val).strip("'")
-            if not v.isnumeric():
-                v = "'" + v + "'"
-            filters.append(col_name + "=" + str(v))
+            query_params['_filter_%s' % col_name] = col_val
+            filters.append('%s = %%(_filter_%s)s' % (col_name, col_name))
 
         seperator = " and "
         where_clause = ""
@@ -311,11 +311,13 @@ async def update_row(self, filter_dict={}, update_dict={}):
 
         sets = []
         for col_name, col_val in update_dict.items():
-            sets.append(col_name + " = " + str(col_val))
+            # The placeholder must use the same '_set_' key that is stored here
+            query_params['_set_%s' % col_name] = col_val
+            sets.append('%s = %%(_set_%s)s' % (col_name, col_name))
 
         set_seperator = ", "
         set_clause = ""
-        if bool(filter_dict):
+        if bool(sets):
             set_clause = set_seperator.join(sets)
         update_sql = """
                 UPDATE {0} SET {1} WHERE {2};
@@ -326,7 +328,7 @@ async def update_row(self, filter_dict={}, update_dict={}):
                     cursor_factory=psycopg2.extras.DictCursor
                 )
             ) as cur:
-                await cur.execute(update_sql)
+                await cur.execute(update_sql, query_params)
                 if cur.rowcount < 1:
                     return DBResponse(response_code=404,
                                       body={"msg": "could not find row"})
@@ -338,7 +340,7 @@ async def update_row(self, filter_dict={}, update_dict={}):
                 cur.close()
                 return DBResponse(response_code=200, body=body)
         except (Exception, psycopg2.DatabaseError) as error:
-            self.db.logger.exception("Exception occured")
+            self.db.logger.exception("Exception occurred")
             return aiopg_exception_handling(error)
 
 
@@ -466,6 +468,11 @@ class AsyncFlowTablePostgres(AsyncPostgresTable):
     )
     _row_type = FlowRow
 
+    @staticmethod
+    def get_filter_dict(flow_id: str):
+        """Return the filter selecting the flow with the given flow_id."""
+        return {"flow_id": flow_id}
+
     async def add_flow(self, flow: FlowRow):
         dict = {
             "flow_id": flow.flow_id,
@@ -476,7 +483,7 @@ async def add_flow(self, flow: FlowRow):
         return await self.create_record(dict)
 
     async def get_flow(self, flow_id: str):
-        filter_dict = {"flow_id": flow_id}
+        filter_dict = self.get_filter_dict(flow_id)
         return await self.get_records(filter_dict=filter_dict, fetch_single=True)
 
     async def get_all_flows(self):
@@ -523,9 +530,14 @@ async def add_run(self, run: RunRow):
         }
         return await self.create_record(dict)
 
-    async def get_run(self, flow_id: str, run_id: str, expanded: bool = False):
+    @staticmethod
+    def get_filter_dict(flow_id: str, run_id: str):
+        """Return the filter selecting one run (key from translate_run_key)."""
         key, value = translate_run_key(run_id)
-        filter_dict = {"flow_id": flow_id, key: str(value)}
+        return {"flow_id": flow_id, key: str(value)}
+
+    async def get_run(self, flow_id: str, run_id: str, expanded: bool = False):
+        filter_dict = self.get_filter_dict(flow_id, run_id)
         return await self.get_records(filter_dict=filter_dict,
                                       fetch_single=True, expanded=expanded)
 
@@ -534,9 +546,7 @@ async def get_all_runs(self, flow_id: str):
         return await self.get_records(filter_dict=filter_dict)
 
     async def update_heartbeat(self, flow_id: str, run_id: str):
-        run_key, run_value = translate_run_key(run_id)
-        filter_dict = {"flow_id": flow_id,
-                       run_key: str(run_value)}
+        filter_dict = self.get_filter_dict(flow_id, run_id)
         set_dict = {
             "last_heartbeat_ts": int(datetime.datetime.utcnow().timestamp())
         }
@@ -589,6 +599,16 @@ async def add_step(self, step_object: StepRow):
         }
         return await self.create_record(dict)
 
+    @staticmethod
+    def get_filter_dict(flow_id: str, run_id: str, step_name: str):
+        """Return the filter selecting one step of a run."""
+        run_id_key, run_id_value = translate_run_key(run_id)
+        return {
+            "flow_id": flow_id,
+            run_id_key: run_id_value,
+            "step_name": step_name,
+        }
+
     async def get_steps(self, flow_id: str, run_id: str):
         run_id_key, run_id_value = translate_run_key(run_id)
         filter_dict = {"flow_id": flow_id,
@@ -596,12 +616,7 @@ async def get_steps(self, flow_id: str, run_id: str):
         return await self.get_records(filter_dict=filter_dict)
 
     async def get_step(self, flow_id: str, run_id: str, step_name: str):
-        run_id_key, run_id_value = translate_run_key(run_id)
-        filter_dict = {
-            "flow_id": flow_id,
-            run_id_key: run_id_value,
-            "step_name": step_name,
-        }
+        filter_dict = self.get_filter_dict(flow_id, run_id, step_name)
         return await self.get_records(filter_dict=filter_dict, fetch_single=True)
 
 
@@ -651,36 +666,36 @@ async def add_task(self, task: TaskRow):
         }
         return await self.create_record(dict)
 
-    async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
+    @staticmethod
+    def get_filter_dict(flow_id: str, run_id: str, step_name: str, task_id: str):
+        """Return the filter selecting one task of a step."""
         run_id_key, run_id_value = translate_run_key(run_id)
-        filter_dict = {
+        task_id_key, task_id_value = translate_task_key(task_id)
+        return {
             "flow_id": flow_id,
             run_id_key: run_id_value,
             "step_name": step_name,
+            task_id_key: task_id_value,
         }
-        return await self.get_records(filter_dict=filter_dict)
 
-    async def get_task(self, flow_id: str, run_id: str, step_name: str,
-                       task_id: str, expanded: bool = False):
+    async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
         run_id_key, run_id_value = translate_run_key(run_id)
-        task_id_key, task_id_value = translate_task_key(task_id)
         filter_dict = {
             "flow_id": flow_id,
             run_id_key: run_id_value,
             "step_name": step_name,
-            task_id_key: task_id_value,
         }
+        return await self.get_records(filter_dict=filter_dict)
+
+    async def get_task(self, flow_id: str, run_id: str, step_name: str,
+                       task_id: str, expanded: bool = False):
+        filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
         return await self.get_records(filter_dict=filter_dict,
                                       fetch_single=True, expanded=expanded)
 
     async def update_heartbeat(self, flow_id: str, run_id: str, step_name: str,
                                task_id: str):
-        run_key, run_value = translate_run_key(run_id)
-        task_key, task_value = translate_task_key(task_id)
-        filter_dict = {"flow_id": flow_id,
-                       run_key: str(run_value),
-                       "step_name": step_name,
-                       task_key: str(task_value)}
+        filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
         set_dict = {
             "last_heartbeat_ts": int(datetime.datetime.utcnow().timestamp())
         }
@@ -757,6 +772,18 @@ async def add_metadata(
         }
         return await self.create_record(dict)
 
+    @staticmethod
+    def get_filter_dict(flow_id: str, run_id: str, step_name: str, task_id: str):
+        """Return the filter selecting the metadata rows of one task."""
+        run_id_key, run_id_value = translate_run_key(run_id)
+        task_id_key, task_id_value = translate_task_key(task_id)
+        return {
+            "flow_id": flow_id,
+            run_id_key: run_id_value,
+            "step_name": step_name,
+            task_id_key: task_id_value,
+        }
+
     async def get_metadata_in_runs(self, flow_id: str, run_id: str):
         run_id_key, run_id_value = translate_run_key(run_id)
         filter_dict = {"flow_id": flow_id,
@@ -764,16 +791,9 @@ async def get_metadata_in_runs(self, flow_id: str, run_id: str):
         return await self.get_records(filter_dict=filter_dict)
 
     async def get_metadata(
-        self, flow_id: str, run_id: int, step_name: str, task_id: str
+        self, flow_id: str, run_id: str, step_name: str, task_id: str
     ):
-        run_id_key, run_id_value = translate_run_key(run_id)
-        task_id_key, task_id_value = translate_task_key(task_id)
-        filter_dict = {
-            "flow_id": flow_id,
-            run_id_key: run_id_value,
-            "step_name": step_name,
-            task_id_key: task_id_value,
-        }
+        filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
         return await self.get_records(filter_dict=filter_dict)
 
 
@@ -856,7 +876,21 @@ async def add_artifact(
         }
         return await self.create_record(dict)
 
-    async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
+    @staticmethod
+    def get_filter_dict(
+            flow_id: str, run_id: str, step_name: str, task_id: str, name: str):
+        """Return the filter selecting one named artifact of a task."""
+        run_id_key, run_id_value = translate_run_key(run_id)
+        task_id_key, task_id_value = translate_task_key(task_id)
+        return {
+            "flow_id": flow_id,
+            run_id_key: run_id_value,
+            "step_name": step_name,
+            task_id_key: task_id_value,
+            '"name"': name,
+        }
+
+    async def get_artifacts_in_runs(self, flow_id: str, run_id: str):
         run_id_key, run_id_value = translate_run_key(run_id)
         filter_dict = {
             "flow_id": flow_id,
@@ -865,7 +899,7 @@ async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
         return await self.get_records(filter_dict=filter_dict,
                                       ordering=self.ordering)
 
-    async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str):
+    async def get_artifact_in_steps(self, flow_id: str, run_id: str, step_name: str):
         run_id_key, run_id_value = translate_run_key(run_id)
         filter_dict = {
             "flow_id": flow_id,
@@ -876,7 +910,7 @@ async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str)
                                       ordering=self.ordering)
 
     async def get_artifact_in_task(
-        self, flow_id: str, run_id: int, step_name: str, task_id: int
+        self, flow_id: str, run_id: str, step_name: str, task_id: str
     ):
         run_id_key, run_id_value = translate_run_key(run_id)
         task_id_key, task_id_value = translate_task_key(task_id)
@@ -890,16 +924,8 @@ async def get_artifact_in_task(
                                       ordering=self.ordering)
 
     async def get_artifact(
-        self, flow_id: str, run_id: int, step_name: str, task_id: int, name: str
+        self, flow_id: str, run_id: str, step_name: str, task_id: str, name: str
     ):
-        run_id_key, run_id_value = translate_run_key(run_id)
-        task_id_key, task_id_value = translate_task_key(task_id)
-        filter_dict = {
-            "flow_id": flow_id,
-            run_id_key: run_id_value,
-            "step_name": step_name,
-            task_id_key: task_id_value,
-            '"name"': name,
-        }
+        filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id, name)
         return await self.get_records(filter_dict=filter_dict,
                                       fetch_single=True, ordering=self.ordering)
diff --git 
a/services/metadata_service/api/tag.py b/services/metadata_service/api/tag.py
new file mode 100644
index 00000000..6454752f
--- /dev/null
+++ b/services/metadata_service/api/tag.py
@@ -0,0 +1,159 @@
+from services.data import TaskRow
+from services.data.db_utils import DBResponse
+from services.data.postgres_async_db import AsyncPostgresDB
+from services.metadata_service.api.utils import format_response, \
+    handle_exceptions
+import json
+
+import asyncio
+
+
+class TagApi(object):
+    """HTTP API for adding and removing user tags on stored objects."""
+
+    # NOTE(review): this lock is never acquired, so the read-modify-write in
+    # update_tags is not atomic; concurrent requests for the same object can
+    # overwrite each other's tag changes.
+    lock = asyncio.Lock()
+
+    # Fields every entry of the request body must provide.
+    _REQUIRED_FIELDS = ('object_type', 'id', 'tag', 'operation')
+    # Number of '/'-separated components expected in 'id' per object type.
+    _PATHSPEC_LENGTHS = {
+        'flow': 1,
+        'run': 2,
+        'step': 3,
+        'task': 4,
+        'artifact': 5,
+    }
+
+    def __init__(self, app):
+        """Register the POST /tags route on app and keep the DB instance."""
+        app.router.add_route(
+            "POST",
+            "/tags",
+            self.update_tags,
+        )
+        self._db = AsyncPostgresDB.get_instance()
+
+    def _get_table(self, type):
+        """Return the table for an object type; raise ValueError if unknown."""
+        if type == 'flow':
+            return self._db.flow_table_postgres
+        elif type == 'run':
+            return self._db.run_table_postgres
+        elif type == 'step':
+            return self._db.step_table_postgres
+        elif type == 'task':
+            return self._db.task_table_postgres
+        elif type == 'artifact':
+            return self._db.artifact_table_postgres
+        else:
+            raise ValueError("cannot find table for type %s" % type)
+
+    @handle_exceptions
+    @format_response
+    async def update_tags(self, request):
+        """
+        ---
+        description: Update user-tags for objects
+        tags:
+        - Tags
+        parameters:
+        - name: "body"
+          in: "body"
+          description: "body"
+          required: true
+          schema:
+            type: array
+            items:
+                type: object
+                required:
+                    - object_type
+                    - id
+                    - tag
+                    - operation
+                properties:
+                    object_type:
+                        type: string
+                        enum: [flow, run, step, task, artifact]
+                    id:
+                        type: string
+                    operation:
+                        type: string
+                        enum: [add, remove]
+                    tag:
+                        type: string
+                    user:
+                        type: string
+        produces:
+        - application/json
+        responses:
+            "200":
+                description: successful operation. Returns the updated objects
+            "400":
+                description: invalid input
+            "404":
+                description: not found
+            "500":
+                description: internal server error
+        """
+        # Response bodies are plain objects: format_response already applies
+        # json.dumps to DBResponse.body, so encoding here would double-encode.
+        body = await request.json()
+        if not isinstance(body, list):
+            return DBResponse(response_code=400, body={
+                "message": "invalid input: expected a list of operations"})
+        results = []
+        for o in body:
+            try:
+                if not isinstance(o, dict):
+                    raise ValueError("each operation must be an object")
+                missing = [f for f in self._REQUIRED_FIELDS if f not in o]
+                if missing:
+                    raise ValueError(
+                        "missing field(s): %s" % ", ".join(missing))
+                if not isinstance(o['id'], str):
+                    raise ValueError("id must be a string")
+                table = self._get_table(o['object_type'])
+                pathspec = o['id'].split('/')
+                # Do some basic verification
+                if len(pathspec) != self._PATHSPEC_LENGTHS[o['object_type']]:
+                    raise ValueError("invalid %s specification: %s" % (
+                        o['object_type'], o['id']))
+                obj_filter = table.get_filter_dict(*pathspec)
+            except ValueError as e:
+                return DBResponse(response_code=400, body={
+                    "message": "invalid input: %s" % str(e)})
+
+            # Now we can get the object
+            obj = await table.get_records(
+                filter_dict=obj_filter, fetch_single=True, expanded=True)
+            if obj.response_code != 200:
+                return DBResponse(response_code=obj.response_code, body={
+                    "message": "could not get object %s: %s" % (o['id'], obj.body)})
+
+            # At this point do some checks and update the tags
+            obj = obj.body
+            modified = False
+            if o['operation'] == 'add':
+                if o['tag'] not in obj['system_tags'] and o['tag'] not in obj['tags']:
+                    modified = True
+                    obj['tags'].append(o['tag'])
+            elif o['operation'] == 'remove':
+                if o['tag'] in obj['tags']:
+                    modified = True
+                    obj['tags'] = [x for x in obj['tags'] if x != o['tag']]
+            else:
+                return DBResponse(response_code=400, body={
+                    "message": "invalid tag operation %s" % o['operation']})
+            if modified:
+                # We save the value back
+                result = await table.update_row(filter_dict=obj_filter, update_dict={
+                    'tags': json.dumps(obj['tags'])})
+                if result.response_code != 200:
+                    return DBResponse(response_code=result.response_code, body={
+                        "message": "error updating tags for %s: %s" % (o['id'], result.body)})
+            results.append(obj)
+
+        return DBResponse(response_code=200, body=results)
diff --git a/services/metadata_service/api/utils.py b/services/metadata_service/api/utils.py
index 527e9160..08529190 100644
--- a/services/metadata_service/api/utils.py
+++ b/services/metadata_service/api/utils.py
@@ -20,6 +20,8 @@ def format_response(func):
     @wraps(func)
     async def wrapper(*args, **kwargs):
         db_response = await func(*args, **kwargs)
+        if isinstance(db_response, web.Response):
+            return db_response
         return web.Response(status=db_response.response_code,
                             body=json.dumps(db_response.body),
                             headers=MultiDict(
diff --git a/services/metadata_service/server.py b/services/metadata_service/server.py
index 2197eece..f41b14f7 100644
--- a/services/metadata_service/server.py
+++ b/services/metadata_service/server.py
@@ -11,6 +11,7 @@
 from .api.task import TaskApi
 from .api.artifact import ArtificatsApi
 from .api.admin import AuthApi
+from .api.tag import TagApi
 from .api.metadata import MetadataApi
 
 from services.data.postgres_async_db import AsyncPostgresDB
@@ -30,6 +31,7 @@ def app(loop=None, db_conf: DBConfiguration = None):
     MetadataApi(app)
     ArtificatsApi(app)
     AuthApi(app)
+    TagApi(app)
     setup_swagger(app)
     return app
 