diff --git a/packages/examples/cvat/exchange-oracle/Dockerfile b/packages/examples/cvat/exchange-oracle/Dockerfile index 86c54f1441..69e547480d 100644 --- a/packages/examples/cvat/exchange-oracle/Dockerfile +++ b/packages/examples/cvat/exchange-oracle/Dockerfile @@ -4,13 +4,17 @@ WORKDIR /app RUN apt-get update -y && \ apt-get install -y jq ffmpeg libsm6 libxext6 && \ - pip install --no-cache poetry + rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache poetry COPY pyproject.toml poetry.lock ./ -RUN poetry config virtualenvs.create false \ - && poetry install --no-interaction --no-ansi --no-root \ - && poetry cache clear pypi --all +RUN --mount=type=cache,target=/root/.cache \ + poetry config virtualenvs.create false && \ + poetry install --no-interaction --no-ansi --no-root + +RUN python -m pip uninstall -y poetry pip COPY . . diff --git a/packages/examples/cvat/exchange-oracle/README.md b/packages/examples/cvat/exchange-oracle/README.md index 7d70eeac82..d6d91ce62b 100644 --- a/packages/examples/cvat/exchange-oracle/README.md +++ b/packages/examples/cvat/exchange-oracle/README.md @@ -18,14 +18,14 @@ For deployment it is required to have PostgreSQL(v14.4) ### Run the oracle locally: -``` +```sh docker compose -f docker-compose.dev.yml up -d ./bin/start_dev.sh ``` or -``` +```sh docker compose -f docker-compose.dev.yml up -d ./bin/start_debug.sh ``` @@ -48,17 +48,17 @@ Example: [Alembic env file](https://github.com/humanprotocol/human-protocol/blob Adding new migration: -``` +```sh alembic revision --autogenerate -m "your-migration-name" ``` Upgrade: -``` +```sh alembic upgrade head ``` Downgrade: -``` +```sh alembic downgrade -{number of migrations} ``` @@ -72,7 +72,7 @@ Available at `/docs` route ### Tests To run tests -``` -docker compose -f docker-compose.test.yml up --build test --attach test --exit-code-from test && \ - docker compose -f docker-compose.test.yml down +```sh +docker compose -p "test" -f docker-compose.test.yml up --build test --attach test --exit-code-from test; \ + docker compose -p "test" -f docker-compose.test.yml down ``` \ No newline at end of file diff --git a/packages/examples/cvat/exchange-oracle/dockerfiles/test.Dockerfile b/packages/examples/cvat/exchange-oracle/dockerfiles/test.Dockerfile index f2341c8d22..eaa436f5d0 100644 --- a/packages/examples/cvat/exchange-oracle/dockerfiles/test.Dockerfile +++ b/packages/examples/cvat/exchange-oracle/dockerfiles/test.Dockerfile @@ -4,16 +4,20 @@ WORKDIR /app RUN apt-get update -y && \ apt-get install -y jq ffmpeg libsm6 libxext6 && \ - pip install --no-cache poetry + rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache poetry COPY pyproject.toml poetry.lock ./ -RUN poetry config virtualenvs.create false \ - && poetry install --no-interaction --no-ansi --no-root \ - && poetry cache clear pypi --all +RUN --mount=type=cache,target=/root/.cache \ + poetry config virtualenvs.create false && \ + poetry install --no-interaction --no-ansi --no-root + +RUN python -m pip uninstall -y poetry pip COPY . . RUN rm -f ./src/.env -CMD ["pytest", "-W", "ignore::DeprecationWarning", "-v"] \ No newline at end of file +CMD ["pytest"] \ No newline at end of file diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index 5dd04c1813..2170018394 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -134,6 +134,7 @@ ignore = [ "ANN001", # | "ANN003", # | "ARG001", # | + "FBT001", # Allow bool-annotated positional args in functions "SLF001", # Allow private attrs access "PLR2004", # Allow magic values "S", # security diff --git a/packages/examples/cvat/exchange-oracle/pytest.ini b/packages/examples/cvat/exchange-oracle/pytest.ini new file mode 100644 index 0000000000..bdf7142e2b --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/pytest.ini @@ -0,0 +1,10 @@ +[pytest] +addopts = --verbose +filterwarnings = + ignore::DeprecationWarning:cvat_sdk.core + ignore::DeprecationWarning:human_protocol_sdk.storage + ignore:Field name \"sort\" shadows:UserWarning:pydantic._internal._fields + +python_files = test_*.py +python_classes = *Test +python_functions = test_* \ No newline at end of file diff --git a/packages/examples/cvat/exchange-oracle/src/core/types.py b/packages/examples/cvat/exchange-oracle/src/core/types.py index 2b67fe9fab..20a04c0b59 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/types.py +++ b/packages/examples/cvat/exchange-oracle/src/core/types.py @@ -10,12 +10,6 @@ class Networks(int, Enum, metaclass=BetterEnumMeta): localhost = Config.localhost.chain_id -class CvatEventTypes(str, Enum, metaclass=BetterEnumMeta): - update_job = "update:job" - create_job = "create:job" - ping = "ping" - - class ProjectStatuses(str, Enum, metaclass=BetterEnumMeta): creation = "creation" annotation = "annotation" @@ -34,7 +28,6 @@ class TaskStatuses(str, Enum, metaclass=BetterEnumMeta): class JobStatuses(str, Enum, metaclass=BetterEnumMeta): new = "new" in_progress = "in progress" - rejected = "rejected" completed = "completed" @@ -47,13 +40,6 @@ class TaskTypes(str, Enum, metaclass=BetterEnumMeta): image_polygons = "image_polygons" -class CvatLabelTypes(str, Enum, metaclass=BetterEnumMeta): - tag = "tag" - points = "points" - rectangle = "rectangle" - polygon = "polygon" - - class OracleWebhookTypes(str, Enum, metaclass=BetterEnumMeta): exchange_oracle = "exchange_oracle" job_launcher = "job_launcher" diff --git a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py index 04252e82bc..213d129a7a 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py @@ -49,8 +49,23 @@ def track_assignments(logger: logging.Logger) -> None: Tracks assignments: 1. Checks time for each active assignment 2. If an assignment is timed out, expires it - 3. If a project or task state is not "annotation", cancels assignments + 3. If an assignment is canceled, resets it + 4. If a project or task state is not "annotation", cancels assignments """ + + def _reset_job_after_assignment(session: Session, assignment: cvat_models.Assignment): + latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( + session, assignment.cvat_job_id + ) + if latest_assignment.id == assignment.id: + # Avoid un-assigning if it's not the latest assignment + + cvat_api.update_job_assignee( + assignment.cvat_job_id, assignee_id=None + ) # note that calling it in a loop can take too much time + + cvat_service.update_job_status(session, assignment.job.id, status=JobStatuses.new) + with SessionLocal.begin() as session: assignments = cvat_service.get_unprocessed_expired_assignments( session, @@ -67,17 +82,27 @@ def track_assignments(logger: logging.Logger) -> None: ) ) - latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( - session, assignment.cvat_job_id - ) - if latest_assignment.id == assignment.id: - # Avoid un-assigning if it's not the latest assignment + cvat_service.expire_assignment(session, assignment.id) + _reset_job_after_assignment(session, assignment) - cvat_api.update_job_assignee( - assignment.cvat_job_id, assignee_id=None - ) # note that calling it in a loop can take too much time + cvat_service.touch(session, cvat_models.Job, [a.job.id for a in assignments]) - cvat_service.expire_assignment(session, assignment.id) + with SessionLocal.begin() as session: + assignments = cvat_service.get_unprocessed_cancelled_assignments( + session, + limit=CronConfig.track_assignments_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + for assignment in assignments: + logger.info( + "Finalizing the canceled assignment {} (user {}, job id {})".format( + assignment.id, + assignment.user_wallet_address, + assignment.cvat_job_id, + ) + ) + _reset_job_after_assignment(session, assignment) cvat_service.touch(session, cvat_models.Job, [a.job.id for a in assignments]) @@ -99,17 +124,8 @@ def track_assignments(logger: logging.Logger) -> None: ) ) - latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( - session, assignment.cvat_job_id - ) - if latest_assignment.id == assignment.id: - # Avoid un-assigning if it's not the latest assignment - - cvat_api.update_job_assignee( - assignment.cvat_job_id, assignee_id=None - ) # note that calling it in a loop can take too much time - cvat_service.cancel_assignment(session, assignment.id) + _reset_job_after_assignment(session, assignment) cvat_service.touch(session, cvat_models.Job, [a.job.id for a in assignments]) diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index 02f442f4c2..5f9f780043 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -38,6 +38,26 @@ class RequestStatus(str, Enum, metaclass=BetterEnumMeta): FAILED = "Failed" +class JobStatus(str, Enum, metaclass=BetterEnumMeta): + new = "new" + in_progress = "in progress" + rejected = "rejected" + completed = "completed" + + +class LabelType(str, Enum, metaclass=BetterEnumMeta): + tag = "tag" + points = "points" + rectangle = "rectangle" + polygon = "polygon" + + +class WebhookEventType(str, Enum, metaclass=BetterEnumMeta): + update_job = "update:job" + create_job = "create:job" + ping = "ping" + + def _request_annotations(endpoint: Endpoint, cvat_id: int, format_name: str) -> str: """ Requests annotations export. diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py b/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py index 0378f6c9b4..ecec61ff7e 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py @@ -4,10 +4,11 @@ import src.models.cvat as models import src.services.cvat as cvat_service from src import db -from src.core.types import AssignmentStatuses, CvatEventTypes, JobStatuses, ProjectStatuses +from src.core.types import AssignmentStatuses, JobStatuses, ProjectStatuses from src.db import SessionLocal from src.db import errors as db_errors from src.log import ROOT_LOGGER_NAME +from src.schemas.cvat import CvatWebhook from src.utils.logging import get_function_logger module_logger_name = f"{ROOT_LOGGER_NAME}.cron.handler" @@ -16,6 +17,11 @@ def handle_update_job_event(payload: dict) -> None: logger = get_function_logger(module_logger_name) + if "state" not in payload.before_update: + return + + new_cvat_status = cvat_api.JobStatus(payload.job["state"]) + with SessionLocal.begin() as session: job_id = payload.job["id"] jobs = cvat_service.get_jobs_by_cvat_id(session, [job_id], for_update=True) @@ -27,77 +33,80 @@ def handle_update_job_event(payload: dict) -> None: job = jobs[0] - if "state" in payload.before_update: - job_assignments = job.assignments - new_status = JobStatuses(payload.job["state"]) + if job.status != JobStatuses.in_progress: + logger.warning( + f"Received a job update webhook for a job id {job_id} " + f"in the status {job.status}, ignoring " + ) + return - if not job_assignments: - logger.warning( - f"Received job #{job.cvat_id} status update: {new_status.value}. " - "No assignments for this job, ignoring the update" - ) - else: - webhook_time = parse_aware_datetime(payload.job["updated_date"]) - webhook_assignee_id = (payload.job["assignee"] or {}).get("id") + # ignore updates for any assignments except the last one + latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( + session, job_id, for_update=True + ) + if not latest_assignment: + logger.warning( + f"Received job #{job.cvat_id} status update: {new_cvat_status.value}. " + "No assignments for this job, ignoring the update" + ) + return - job_assignments: list[models.Assignment] = sorted( - job_assignments, key=lambda a: a.created_at, reverse=True - ) - latest_assignment = job.assignments[0] - matching_assignment = next( - ( - a - for a in job_assignments - if a.user.cvat_id == webhook_assignee_id - if a.created_at < webhook_time - ), - None, - ) + webhook_time = parse_aware_datetime(payload.job["updated_date"]) + webhook_assignee_id = (payload.job["assignee"] or {}).get("id") + + matching_assignment = next( + ( + a + for a in [latest_assignment] + if a.user.cvat_id == webhook_assignee_id + if a.created_at < webhook_time + ), + None, + ) - if not matching_assignment: - logger.warning( - f"Received job #{job.cvat_id} status update: {new_status.value}. " - "Can't find a matching assignment, ignoring the update" - ) - elif matching_assignment.is_finished: - if matching_assignment.status == AssignmentStatuses.created: - logger.warning( - f"Received job #{job.cvat_id} status update: {new_status.value}. " - "Assignment is expired, rejecting the update" - ) - cvat_service.expire_assignment(session, matching_assignment.id) - cvat_service.touch(session, models.Job, [matching_assignment.job.id]) - - if matching_assignment.id == latest_assignment.id: - cvat_api.update_job_assignee(job.cvat_id, assignee_id=None) - - else: - logger.info( - f"Received job #{job.cvat_id} status update: {new_status.value}. " - "Assignment is already finished, ignoring the update" - ) - elif ( - new_status == JobStatuses.completed - and matching_assignment.id == latest_assignment.id - and matching_assignment.status == AssignmentStatuses.created - ): - logger.info( - f"Received job #{job.cvat_id} status update: {new_status.value}. " - "Completing the assignment" - ) - cvat_service.complete_assignment( - session, matching_assignment.id, completed_at=webhook_time - ) - cvat_service.update_job_status(session, job.id, new_status) - cvat_service.touch(session, models.Job, [job.id]) + if not matching_assignment: + logger.warning( + f"Received job #{job.cvat_id} status update: {new_cvat_status.value}. " + "No matching assignment or the assignment is too old, ignoring the update" + ) + elif matching_assignment.is_finished: + if matching_assignment.status == AssignmentStatuses.created: + logger.warning( + f"Received job #{job.cvat_id} status update: {new_cvat_status.value}. " + "Assignment is expired, rejecting the update" + ) + cvat_service.expire_assignment(session, matching_assignment.id) + if matching_assignment.id == latest_assignment.id: cvat_api.update_job_assignee(job.cvat_id, assignee_id=None) + cvat_service.update_job_status(session, job.id, status=JobStatuses.new) - else: - logger.info( - f"Received job #{job.cvat_id} status update: {new_status.value}. " - "Ignoring the update" - ) + cvat_service.touch(session, models.Job, [job.id]) + else: + logger.info( + f"Received job #{job.cvat_id} status update: {new_cvat_status.value}. " + "Assignment is already finished, ignoring the update" + ) + elif ( + new_cvat_status == cvat_api.JobStatus.completed + and matching_assignment.id == latest_assignment.id + and matching_assignment.is_finished == False + ): + logger.info( + f"Received job #{job.cvat_id} status update: {new_cvat_status.value}. " + "Completing the assignment" + ) + cvat_service.complete_assignment( + session, matching_assignment.id, completed_at=webhook_time + ) + cvat_api.update_job_assignee(job.cvat_id, assignee_id=None) + cvat_service.update_job_status(session, job.id, status=JobStatuses.completed) + cvat_service.touch(session, models.Job, [job.id]) + else: + logger.info( + f"Received job #{job.cvat_id} status update: {new_cvat_status.value}. " + "Ignoring the update" + ) def handle_create_job_event(payload: dict) -> None: @@ -167,11 +176,11 @@ def handle_create_job_event(payload: dict) -> None: ) -def cvat_webhook_handler(cvat_webhook: dict) -> None: +def cvat_webhook_handler(cvat_webhook: CvatWebhook) -> None: match cvat_webhook.event: - case CvatEventTypes.update_job.value: + case cvat_api.WebhookEventType.update_job.value: handle_update_job_event(cvat_webhook) - case CvatEventTypes.create_job.value: + case cvat_api.WebhookEventType.create_job.value: handle_create_job_event(cvat_webhook) - case CvatEventTypes.ping.value: + case cvat_api.WebhookEventType.ping.value: pass diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 566d69a959..1f94262a3f 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -33,7 +33,7 @@ from src.chain.escrow import get_escrow_manifest from src.core.config import Config from src.core.storage import compose_data_bucket_filename, compose_data_bucket_prefix -from src.core.types import CvatLabelTypes, TaskStatuses, TaskTypes +from src.core.types import TaskStatuses, TaskTypes from src.db import SessionLocal from src.log import ROOT_LOGGER_NAME from src.models.cvat import Project @@ -54,12 +54,12 @@ module_logger = f"{ROOT_LOGGER_NAME}.cron.cvat" LABEL_TYPE_MAPPING = { - TaskTypes.image_label_binary: CvatLabelTypes.tag, - TaskTypes.image_points: CvatLabelTypes.points, - TaskTypes.image_boxes: CvatLabelTypes.rectangle, - TaskTypes.image_polygons: CvatLabelTypes.polygon, - TaskTypes.image_boxes_from_points: CvatLabelTypes.rectangle, - TaskTypes.image_skeletons_from_boxes: CvatLabelTypes.points, + TaskTypes.image_label_binary: cvat_api.LabelType.tag, + TaskTypes.image_points: cvat_api.LabelType.points, + TaskTypes.image_boxes: cvat_api.LabelType.rectangle, + TaskTypes.image_polygons: cvat_api.LabelType.polygon, + TaskTypes.image_boxes_from_points: cvat_api.LabelType.rectangle, + TaskTypes.image_skeletons_from_boxes: cvat_api.LabelType.points, } DM_DATASET_FORMAT_MAPPING = { @@ -230,7 +230,7 @@ def _setup_gt_job_for_cvat_task( with TemporaryDirectory() as tmp_dir: export_dir = Path(tmp_dir) / "export" - gt_dataset.export(save_dir=str(export_dir), save_images=False, format=dm_export_format) + gt_dataset.export(save_dir=str(export_dir), save_media=False, format=dm_export_format) annotations_archive_path = Path(tmp_dir) / "annotations.zip" with annotations_archive_path.open("wb") as annotations_archive: diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py index 6b93567965..a3d67c7eb6 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py @@ -130,7 +130,7 @@ def _parse_dataset(self, ann_descriptor: FileDescriptor, dataset_dir: str) -> dm return dm.Dataset.import_from(dataset_dir, self.input_format) def _export_dataset(self, dataset: dm.Dataset, output_dir: str): - dataset.export(output_dir, self.output_format, save_images=False) + dataset.export(output_dir, self.output_format, save_media=False) def _process_dataset( self, dataset: dm.Dataset, *, ann_descriptor: FileDescriptor diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 06f3ef5aa7..09768227dd 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -12,6 +12,7 @@ from sqlalchemy import delete, func, literal, select, update from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session +from sqlalchemy.sql.functions import coalesce from src.core.types import ( AssignmentStatuses, @@ -761,7 +762,7 @@ def get_free_job( for_update: bool | ForUpdateParams = False, ) -> Job | None: """ - Returns the first available job that wasn't previously assigned to that user_walled_address. + Returns the first available job that wasn't previously assigned to that user_wallet_address. """ return ( _maybe_for_update(session.query(Job), enable=for_update) @@ -772,14 +773,7 @@ def get_free_job( & (Project.status == ProjectStatuses.annotation) ), Job.status == JobStatuses.new, - ~Job.assignments.any( - ( - (Assignment.status == AssignmentStatuses.created.value) - & (Assignment.completed_at == None) - & (utcnow() < Assignment.expires_at) - ) - | (Assignment.user_wallet_address == user_wallet_address), - ), + ~Job.assignments.any(Assignment.user_wallet_address == user_wallet_address), ) .first() ) @@ -881,13 +875,28 @@ def get_unprocessed_expired_assignments( ) +def get_unprocessed_cancelled_assignments( + session: Session, *, limit: int = 10, for_update: bool | ForUpdateParams = False +) -> list[Assignment]: + return ( + _maybe_for_update(session.query(Assignment), enable=for_update) + .where( + (Assignment.job.has(Job.status == JobStatuses.in_progress.value)) + & (Assignment.status == AssignmentStatuses.canceled.value) + ) + .limit(limit) + .all() + ) + + def get_active_assignments( session: Session, *, limit: int = 10, for_update: bool | ForUpdateParams = False ) -> list[Assignment]: return ( _maybe_for_update(session.query(Assignment), enable=for_update) .where( - (Assignment.status == AssignmentStatuses.created.value) + (Assignment.job.has(Job.status == JobStatuses.in_progress.value)) + & (Assignment.status == AssignmentStatuses.created.value) & (Assignment.completed_at == None) & (Assignment.expires_at <= utcnow()) ) @@ -1020,7 +1029,14 @@ def touch( if time is None: time = utcnow() - session.execute(update(cls).where(cls.id.in_(ids)).values({cls.updated_at: time})) + session.execute( + update(cls) + .where( + cls.id.in_(ids), + coalesce(cls.updated_at, datetime.min) < time, + ) + .values({cls.updated_at: time}) + ) if touch_parents: touch_parent_objects(session, cls, ids, time=time) @@ -1033,6 +1049,9 @@ def touch_parent_objects( *, time: datetime | None = None, ): + if time is None: + time = utcnow() + while issubclass(cls, ChildOf): parent_cls = cls.parent_cls foreign_key_column = next(iter(cls.parent.property.local_columns)) @@ -1044,7 +1063,8 @@ def touch_parent_objects( select(foreign_key_column) .where(cls.id.in_(ids)) .where(foreign_key_column.is_not(None)) - ) + ), + coalesce(parent_cls.updated_at, datetime.min) < time, ) .values({parent_cls.updated_at: time}) .returning(parent_cls.id) diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index 23a7b8e1cc..7d514597a0 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -1,9 +1,11 @@ +from contextlib import suppress from datetime import timedelta import src.cvat.api_calls as cvat_api import src.services.cvat as cvat_service -from src.core.types import Networks, ProjectStatuses, TaskTypes +from src.core.types import JobStatuses, Networks, ProjectStatuses, TaskTypes from src.db import SessionLocal +from src.db.utils import ForUpdateParams from src.models.cvat import Job from src.utils.assignments import get_default_assignment_timeout from src.utils.requests import get_or_404 @@ -23,7 +25,7 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s user = get_or_404( cvat_service.get_user_by_id(session, wallet_address, for_update=True), wallet_address, - "user", + object_type_name="user", ) if cvat_service.has_active_user_assignments( @@ -43,7 +45,7 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s session, escrow_address, status_in=[ProjectStatuses.annotation] ), escrow_address, - "job", + object_type_name="job", ) unassigned_job = cvat_service.get_free_job( @@ -51,7 +53,7 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s escrow_address=escrow_address, chain_id=chain_id.value, user_wallet_address=wallet_address, - for_update=True, + for_update=ForUpdateParams(skip_locked=True), # lock the job to be able to make a rollback if CVAT requests fail # can potentially be optimized to make less DB requests # and rely only on assignment expiration @@ -73,6 +75,7 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s ), ) + cvat_service.update_job_status(session, unassigned_job.id, status=JobStatuses.in_progress) cvat_service.touch(session, Job, [unassigned_job.id]) with cvat_api.api_client_context(cvat_api.get_api_client()): @@ -91,7 +94,9 @@ class NoAccessError(Exception): async def resign_assignment(assignment_id: str, wallet_address: str) -> None: with SessionLocal.begin() as session: assignments = cvat_service.get_assignments_by_id(session, [assignment_id], for_update=True) - assignment = get_or_404(next(iter(assignments), None), assignment_id, "assignment") + assignment = get_or_404( + next(iter(assignments), None), assignment_id, object_type_name="assignment" + ) # Can only resign from an active assignment in a job # TODO: maybe optimize to a single DB request @@ -103,12 +108,20 @@ async def resign_assignment(assignment_id: str, wallet_address: str) -> None: raise NoAccessError last_job_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( - session, assignment.cvat_job_id, for_update=True + session, + assignment.cvat_job_id, + for_update=ForUpdateParams(skip_locked=True), ) - if assignment.id != last_job_assignment.id: + if not last_job_assignment or assignment.id != last_job_assignment.id: raise NoAccessError cvat_service.cancel_assignment(session, assignment_id) - job = assignment.job - cvat_service.touch(session, Job, [job.id]) # project|task rows are locked for update + # Try to update the status, but don't insist + with suppress(cvat_api.exceptions.ApiException): + cvat_api.update_job_assignee(assignment.cvat_job_id, assignee_id=None) + + # Update the job only if assignee was unset + cvat_service.update_job_status(session, assignment.job.id, status=JobStatuses.new) + + cvat_service.touch(session, Job, [assignment.job.id]) diff --git a/packages/examples/cvat/exchange-oracle/src/utils/requests.py b/packages/examples/cvat/exchange-oracle/src/utils/requests.py index ef2174f9b9..73519f3947 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/requests.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/requests.py @@ -9,8 +9,8 @@ def get_or_404( obj: T | None, object_id: V, - object_type_name: str, *, + object_type_name: str, reason: str | None = None, ) -> T: if obj is None: diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py index 2cfff389c8..3c54224f69 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py @@ -1,12 +1,15 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from unittest.mock import patch +import pytest from fastapi.testclient import TestClient from src.core.types import AssignmentStatuses, JobStatuses +from src.utils.time import utcnow +from tests.utils.constants import WALLET_ADDRESS1, WALLET_ADDRESS2 from tests.utils.setup_cvat import ( - add_asignment_to_db, + add_assignment_to_db, add_cvat_job_to_db, add_cvat_project_to_db, add_cvat_task_to_db, @@ -14,20 +17,19 @@ get_cvat_job_from_db, ) -api_url = "http://localhost:8080/api/" +API_URL = "http://localhost:8080/api/" +PING_EVENT_DATA = { + "event": "ping", +} -def test_ping_incoming_webhook(client: TestClient) -> None: - data = { - "event": "ping", - } - signature = generate_cvat_signature(data) +def test_ping_incoming_webhook(client: TestClient) -> None: # Should respond with 200 status to a "ping" event response = client.post( "/cvat-webhook", - headers={"X-Signature-256": signature}, - json=data, + headers={"X-Signature-256": generate_cvat_signature(PING_EVENT_DATA)}, + json=PING_EVENT_DATA, ) assert response.status_code == 200 @@ -36,13 +38,13 @@ def test_ping_incoming_webhook(client: TestClient) -> None: def test_incoming_webhook_200(client: TestClient) -> None: # Create some entities in test DB add_cvat_project_to_db(cvat_id=1) - add_cvat_task_to_db(cvat_id=1, cvat_project_id=1, status="annotation") + add_cvat_task_to_db(cvat_id=1, cvat_project_id=1) # Payload for "create:job" event data = { "event": "create:job", "job": { - "url": api_url + "jobs/1", + "url": API_URL + "jobs/1", "id": 1, "task_id": 1, "project_id": 1, @@ -71,19 +73,30 @@ def test_incoming_webhook_200(client: TestClient) -> None: assert job.cvat_project_id == 1 -def test_incoming_webhook_200_update_expired_assignmets(client: TestClient) -> None: +@pytest.mark.parametrize("is_last_assignment", [True, False]) +def test_incoming_webhook_can_update_expired_assignment( + client: TestClient, is_last_assignment: bool +): + # Check if an "update:job" event can update an expired assignment, + # if the assignment is the last one for the job. Updates to other assignments should be ignored. + add_cvat_project_to_db(cvat_id=1) - add_cvat_task_to_db(cvat_id=1, cvat_project_id=1, status="annotation") - add_cvat_job_to_db(cvat_id=1, cvat_task_id=1, cvat_project_id=1, status="new") - (job, _) = get_cvat_job_from_db(1) - # Check if "update:job" event works with expired assignments - wallet_address = "0x86e83d346041E8806e352681f3F14549C0d2BC68" - add_asignment_to_db(wallet_address, 1, job.cvat_id, datetime.now(tz=timezone.utc)) + add_cvat_task_to_db(cvat_id=1, cvat_project_id=1) + job = add_cvat_job_to_db( + cvat_id=1, cvat_task_id=1, cvat_project_id=1, status=JobStatuses.in_progress + ) + + user_cvat_id = 1 + add_assignment_to_db(WALLET_ADDRESS1, user_cvat_id, job.cvat_id, expires_at=utcnow()) + + if not is_last_assignment: + user_cvat_id += 1 + add_assignment_to_db(WALLET_ADDRESS2, user_cvat_id, job.cvat_id, expires_at=utcnow()) data = { "event": "update:job", "job": { - "url": api_url + "jobs/1", + "url": API_URL + "jobs/1", "id": 1, "task_id": 1, "project_id": 1, @@ -91,44 +104,57 @@ def test_incoming_webhook_200_update_expired_assignmets(client: TestClient) -> N "start_frame": 0, "stop_frame": 1, "assignee": { - "url": api_url + "users/1", - "id": 1, + "url": API_URL + f"users/{user_cvat_id}", + "id": user_cvat_id, }, - "updated_date": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z", + "updated_date": (utcnow() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z", }, "before_update": {"state": "new", "assignee": None}, "webhook_id": 1, } - signature = generate_cvat_signature(data) - - with patch("src.handlers.cvat_events.cvat_api"): + with patch("src.handlers.cvat_events.cvat_api.update_job_assignee") as mock_update_job_assignee: response = client.post( "/cvat-webhook", - headers={"X-Signature-256": signature}, + headers={"X-Signature-256": generate_cvat_signature(data)}, json=data, ) assert response.status_code == 200 - (job, asignees) = get_cvat_job_from_db(1) - assert job.status == JobStatuses.new.value - assert asignees[0].status == AssignmentStatuses.expired.value + (job, assignments) = get_cvat_job_from_db(1) + assert job.status == JobStatuses.new + assert assignments[-1].status == AssignmentStatuses.expired + mock_update_job_assignee.assert_called_once_with(job.cvat_id, assignee_id=None) + + if not is_last_assignment: + for assignment in assignments[:-1]: + assert assignment.status == AssignmentStatuses.created -def test_incoming_webhook_200_update(client: TestClient) -> None: +@pytest.mark.parametrize("assignment_status", AssignmentStatuses) +def test_incoming_webhook_can_update_active_assignment( + client: TestClient, assignment_status: AssignmentStatuses +): add_cvat_project_to_db(cvat_id=1) - add_cvat_task_to_db(cvat_id=1, cvat_project_id=1, status="annotation") - add_cvat_job_to_db(cvat_id=1, cvat_task_id=1, cvat_project_id=1, status="new") - (job, _) = get_cvat_job_from_db(1) - # Check if "update:job" event works correctly - wallet_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" - add_asignment_to_db(wallet_address, 2, job.cvat_id, datetime.now() + timedelta(hours=1)) + add_cvat_task_to_db(cvat_id=1, cvat_project_id=1) + job = add_cvat_job_to_db( + cvat_id=1, cvat_task_id=1, cvat_project_id=1, status=JobStatuses.in_progress + ) + add_assignment_to_db( + WALLET_ADDRESS1, + 1, + job.cvat_id, + status=assignment_status, + expires_at=datetime.now() + if assignment_status == AssignmentStatuses.expired + else datetime.now() + timedelta(hours=1), + ) data = { "event": "update:job", "job": { - "url": api_url + "jobs/1", + "url": API_URL + "jobs/1", "id": 1, "task_id": 1, "project_id": 1, @@ -136,34 +162,33 @@ def test_incoming_webhook_200_update(client: TestClient) -> None: "start_frame": 0, "stop_frame": 1, "assignee": { - "url": api_url + "users/1", - "id": 2, + "url": API_URL + "users/1", + "id": 1, }, - "updated_date": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z", + "updated_date": utcnow().strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z", }, - "before_update": {"state": "new", "assignee": None}, + "before_update": {"state": "in_progress", "assignee": None}, "webhook_id": 1, } - signature = generate_cvat_signature(data) - - with patch("src.handlers.cvat_events.cvat_api"): + with patch("src.handlers.cvat_events.cvat_api.update_job_assignee") as mock_update_job_assignee: response = client.post( "/cvat-webhook", - headers={"X-Signature-256": signature}, + headers={"X-Signature-256": generate_cvat_signature(data)}, json=data, ) assert response.status_code == 200 - (job, asignees) = get_cvat_job_from_db(1) - assert job.status == JobStatuses.completed.value - assert asignees[0].status == AssignmentStatuses.completed.value - - -data = { - "event": "ping", -} + (job, assignments) = get_cvat_job_from_db(1) + if assignment_status == AssignmentStatuses.created: + assert job.status == JobStatuses.completed + assert assignments[0].status == AssignmentStatuses.completed + mock_update_job_assignee.assert_called_once_with(job.cvat_id, assignee_id=None) + else: + assert job.status == JobStatuses.in_progress + assert assignments[0].status == assignment_status + mock_update_job_assignee.assert_not_called() def test_incoming_webhook_401_bad_signature(client: TestClient) -> None: @@ -171,7 +196,7 @@ def test_incoming_webhook_401_bad_signature(client: TestClient) -> None: response = client.post( "/cvat-webhook", headers={"X-Signature-256": "dummy_signature"}, - json=data, + json=PING_EVENT_DATA, ) assert response.status_code == 401 assert response.json() == {"message": "Unauthorized"} @@ -180,7 +205,7 @@ def test_incoming_webhook_401_bad_signature(client: TestClient) -> None: def test_incoming_webhook_401_without_signature(client: TestClient) -> None: response = client.post( "/cvat-webhook", - json=data, + json=PING_EVENT_DATA, ) # Send a request without a signature diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py index ef6e0b02c4..a310398829 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py @@ -15,13 +15,14 @@ from sqlalchemy.orm import Session from src.core.config import Config -from src.core.types import AssignmentStatuses, ProjectStatuses, TaskTypes +from src.core.types import AssignmentStatuses, JobStatuses, ProjectStatuses, TaskTypes from src.models.cvat import Assignment, Job, Project, Task, User from src.schemas.exchange import AssignmentStatuses as APIAssignmentStatuses from src.schemas.exchange import JobStatuses as APIJobStatuses -from src.services import cvat +from src.services import cvat as cvat_service from src.utils.time import utcnow +from tests.utils.constants import WALLET_ADDRESS1 from tests.utils.db_helper import ( create_job, create_project, @@ -30,8 +31,6 @@ create_task, ) -escrow_address = "0x12E66A452f95bff49eD5a30b0d06Ebc37C5A94B6" -user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC60" cvat_email = "test@hmt.ai" @@ -68,7 +67,7 @@ def generate_jwt_token( return jwt.encode(data, private_key, algorithm="ES256") -def get_auth_header(token: str = generate_jwt_token(wallet_address=user_address)) -> dict: +def get_auth_header(token: str = generate_jwt_token(wallet_address=WALLET_ADDRESS1)) -> dict: return {"Authorization": f"Bearer {token}"} @@ -148,7 +147,7 @@ def validate_result( session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -162,7 +161,7 @@ def validate_result( ) assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, expires_at=utcnow() + timedelta(days=1), ) @@ -208,7 +207,7 @@ def test_can_list_jobs_200_without_escrows_in_hidden_states( ) -> None: session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -271,7 +270,7 @@ def test_can_list_jobs_200_with_only_one_entry_per_escrow_address_if_several_pro session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -303,7 +302,7 @@ def test_can_list_jobs_200_with_only_one_entry_per_escrow_address_if_several_pro def test_can_list_jobs_200_with_fields(client: TestClient, session: Session) -> None: session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -314,7 +313,7 @@ def test_can_list_jobs_200_with_fields(client: TestClient, session: Session) -> ) assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, expires_at=utcnow() + timedelta(days=1), ) @@ -361,7 +360,7 @@ def test_can_list_jobs_200_with_sorting(client: TestClient, session: Session) -> # sort: ASC, DESC; sort_field: chain_id|job_type|created_at|updated_at session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -375,14 +374,14 @@ def test_can_list_jobs_200_with_sorting(client: TestClient, session: Session) -> cvat_project, cvat_task, cvat_job = create_project_task_and_job( session, f"0x86e83d346041E8806e352681f3F14549C0d2BC6{i}", i + 1 ) - cvat.touch(session, Job, [cvat_job.id]) + cvat_service.touch(session, Job, [cvat_job.id]) cvat_projects.append(cvat_project) cvat_tasks.append(cvat_task) cvat_jobs.append(cvat_job) assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, expires_at=utcnow() + timedelta(hours=i + 1), status=AssignmentStatuses.created if i % 2 else AssignmentStatuses.completed, @@ -392,7 +391,7 @@ def test_can_list_jobs_200_with_sorting(client: TestClient, session: Session) -> session.commit() last_updated_job = cvat_jobs[1] - cvat.touch(session, Job, [last_updated_job.id]) + cvat_service.touch(session, Job, [last_updated_job.id]) session.commit() assert { @@ -444,7 +443,7 @@ def test_can_list_jobs_200_with_sorting(client: TestClient, session: Session) -> def test_can_list_jobs_200_with_filters(client: TestClient, session: Session): session.begin() user_1 = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -490,7 +489,7 @@ def test_can_list_jobs_200_with_filters(client: TestClient, session: Session): session.add(assignment) assignments.append(assignment) - cvat.touch(session, Job, [cvat_job.id]) + cvat_service.touch(session, Job, [cvat_job.id]) session.commit() # TODO: imitate different created_dates visible_projects_ids = set( @@ -510,7 +509,7 @@ def test_can_list_jobs_200_with_filters(client: TestClient, session: Session): updated_cvat_project_ids = set() for job in cvat_jobs[len(cvat_jobs) // 2 :]: - cvat.touch(session, Job, [job.id]) + cvat_service.touch(session, Job, [job.id]) updated_cvat_project_ids.add(job.task.cvat_project_id) session.commit() @@ -568,7 +567,7 @@ def test_can_list_jobs_200_with_filters(client: TestClient, session: Session): def test_can_list_jobs_200_check_values(client: TestClient, session: Session) -> None: session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -584,12 +583,12 @@ def test_can_list_jobs_200_check_values(client: TestClient, session: Session) -> for job in (cvat_second_job, cvat_first_job): assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=job.cvat_id, expires_at=utcnow() + timedelta(days=1), ) session.add(assignment) - cvat.touch(session, Job, [job.id]) + cvat_service.touch(session, Job, [job.id]) session.commit() with ( @@ -631,7 +630,7 @@ def test_can_list_jobs_200_without_address(client: TestClient, session: Session) create_project_task_and_job(session, "0x86e83d346041E8806e352681f3F14549C0d2BC69", 3) user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -639,7 +638,7 @@ def test_can_list_jobs_200_without_address(client: TestClient, session: Session) assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job_1.cvat_id, expires_at=utcnow() + timedelta(days=1), ) @@ -705,15 +704,15 @@ def test_can_register_200(client: TestClient, session: Session) -> None: assert response.status_code == 200 user = response.json() - db_user = session.query(User).where(User.wallet_address == user_address).first() - assert user["wallet_address"] == db_user.wallet_address == user_address + db_user = session.query(User).where(User.wallet_address == WALLET_ADDRESS1).first() + assert user["wallet_address"] == db_user.wallet_address == WALLET_ADDRESS1 assert user["email"] == db_user.cvat_email == cvat_email def test_cannot_register_400_with_duplicated_address(client: TestClient, session: Session) -> None: session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -724,7 +723,7 @@ def test_cannot_register_400_with_duplicated_address(client: TestClient, session response = client.post( "/register", headers=get_auth_header( - generate_jwt_token(wallet_address=user_address, email=new_cvat_email) + generate_jwt_token(wallet_address=WALLET_ADDRESS1, email=new_cvat_email) ), ) assert response.status_code == 400 @@ -733,9 +732,9 @@ def test_cannot_register_400_with_duplicated_address(client: TestClient, session def test_cannot_register_400_with_duplicated_user(client: TestClient, session: Session) -> None: session.begin() - new_user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC61" + new_WALLET_ADDRESS1 = "0x86e83d346041E8806e352681f3F14549C0d2BC61" user = User( - wallet_address=new_user_address, + wallet_address=new_WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -747,7 +746,7 @@ def test_cannot_register_400_with_duplicated_user(client: TestClient, session: S ) assert response.status_code == 400 assert response.json() == {"message": "User already exists"} - assert new_user_address != user_address + assert new_WALLET_ADDRESS1 != WALLET_ADDRESS1 def test_cannot_register_401(client: TestClient) -> None: @@ -771,12 +770,13 @@ def test_can_create_assignment_200(client: TestClient, session: Session) -> None session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) session.add(user) session.commit() + with ( open("tests/utils/manifest.json") as data, patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, @@ -815,7 +815,7 @@ def test_can_create_assignment_200(client: TestClient, session: Session) -> None } db_assignment = ( - session.query(Assignment).filter_by(user_wallet_address=user_address).first() + session.query(Assignment).filter_by(user_wallet_address=WALLET_ADDRESS1).first() ) assert assignment["escrow_address"] == cvat_project.escrow_address assert assignment["chain_id"] == cvat_project.chain_id @@ -841,7 +841,7 @@ def test_cannot_create_assignment_401(client: TestClient) -> None: response = client.post( "/assignment", headers=get_auth_header(token) if token else None, - json={"wallet_address": user_address, "cvat_email": cvat_email}, + json={"wallet_address": WALLET_ADDRESS1, "cvat_email": cvat_email}, ) assert response.status_code == 401 @@ -859,7 +859,7 @@ def test_cannot_create_assignment_400_when_has_unfinished_assignments( create_job(session, 2, cvat_task.cvat_id, cvat_project.cvat_id) user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -868,7 +868,7 @@ def test_cannot_create_assignment_400_when_has_unfinished_assignments( assignment = Assignment( created_at=utcnow(), expires_at=utcnow() + timedelta(hours=1), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job1.cvat_id, status=AssignmentStatuses.created.value, ) @@ -892,7 +892,7 @@ def test_cannot_create_assignment_400_when_has_unfinished_assignments( def test_can_list_assignments_200(client: TestClient, session: Session) -> None: session.begin() user_1 = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -1020,7 +1020,7 @@ def test_can_list_assignments_200_with_sorting(client: TestClient, session: Sess # sort: ASC, DESC session.begin() user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -1033,7 +1033,7 @@ def test_can_list_assignments_200_with_sorting(client: TestClient, session: Sess assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, expires_at=utcnow() + timedelta(hours=i + 1), status=AssignmentStatuses.created if i % 2 else AssignmentStatuses.completed, @@ -1086,33 +1086,42 @@ def test_can_resign_assignment_200(client: TestClient, session: Session) -> None cvat_project, cvat_task, cvat_job = create_project_task_and_job( session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) + cvat_job.status = JobStatuses.in_progress + cvat_job.updated_at = None + user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, expires_at=utcnow() + timedelta(hours=1), + status=AssignmentStatuses.created, ) - session.add_all([user, assignment]) + session.add_all([cvat_job, user, assignment]) + session.commit() assert {cvat_job.updated_at, cvat_task.updated_at, cvat_job.updated_at} == {None} - response = client.post( - "/assignment/resign", - headers=get_auth_header(), - json={"assignment_id": assignment.id}, - ) + + with patch("src.services.exchange.cvat_api.update_job_assignee") as mock_update_job_assignee: + response = client.post( + "/assignment/resign", + headers=get_auth_header(), + json={"assignment_id": assignment.id}, + ) + + mock_update_job_assignee.assert_called_once_with(cvat_job.cvat_id, assignee_id=None) assert response.status_code == 200 session.refresh(assignment) assert assignment.status == AssignmentStatuses.canceled - for obj in cvat_project, cvat_task, cvat_job: + for obj in (cvat_project, cvat_task, cvat_job): session.refresh(obj) assert obj.updated_at is not None assert cvat_project.updated_at == cvat_task.updated_at == cvat_job.updated_at @@ -1144,14 +1153,14 @@ def test_cannot_resign_assignment_400_when_assignment_is_finished( session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) session.add(user) assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job_1.cvat_id, expires_at=utcnow() + timedelta(hours=1), status=AssignmentStatuses.completed.value, @@ -1229,7 +1238,7 @@ def test_can_get_assignment_stats_by_worker_200(client: TestClient, session: Ses cvat_jobs.append(cvat_job) user = User( - wallet_address=user_address, + wallet_address=WALLET_ADDRESS1, cvat_email=cvat_email, cvat_id=1, ) @@ -1245,7 +1254,7 @@ def test_can_get_assignment_stats_by_worker_200(client: TestClient, session: Ses ): assignment = Assignment( id=str(uuid.uuid4()), - user_wallet_address=user_address, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_jobs[i].cvat_id, expires_at=utcnow() + timedelta(hours=1), status=status, @@ -1376,7 +1385,6 @@ def test_can_list_jobs_200_check_updated_at(client: TestClient, session: Session ] session.add_all(users) - utcnow() cvat_project = create_project(session, "0x86e83d346041E8806e352681f3F14549C0d2BC66", 1) cvat_tasks: list[Task] = [] cvat_jobs: list[Job] = [] diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py index 2acae81df7..b4b2679dc1 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py @@ -3,17 +3,15 @@ from datetime import datetime, timedelta from unittest.mock import patch -import pytest -from sqlalchemy import update - from src.core.types import ( AssignmentStatuses, - ProjectStatuses, + JobStatuses, ) from src.crons.cvat.state_trackers import track_assignments from src.db import SessionLocal -from src.models.cvat import Assignment, Project, User +from src.models.cvat import Assignment, Job, User +from tests.utils.constants import ESCROW_ADDRESS, WALLET_ADDRESS1, WALLET_ADDRESS2 from tests.utils.db_helper import create_project_task_and_job @@ -24,128 +22,118 @@ def setUp(self): def tearDown(self): self.session.close() - def test_track_expired_assignments(self): - (_, _, cvat_job) = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) - wallet_address_1 = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + def test_can_track_expired_assignments(self): + (_, _, cvat_job) = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + cvat_job.status = JobStatuses.in_progress + self.session.add(cvat_job) + user = User( - wallet_address=wallet_address_1, + wallet_address=WALLET_ADDRESS1, cvat_email="test@hmt.ai", cvat_id=1, ) self.session.add(user) - wallet_address_2 = "0x86e83d346041E8806e352681f3F14549C0d2BC68" user = User( - wallet_address=wallet_address_2, + wallet_address=WALLET_ADDRESS2, cvat_email="test2@hmt.ai", cvat_id=2, ) self.session.add(user) - assignment = Assignment( + + assignment1 = Assignment( id=str(uuid.uuid4()), - user_wallet_address=wallet_address_1, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), + created_at=datetime.now() - timedelta(hours=2), + expires_at=datetime.now() - timedelta(hours=1), + status=AssignmentStatuses.created, ) - assignment_2 = Assignment( + self.session.add(assignment1) + + assignment2 = Assignment( id=str(uuid.uuid4()), - user_wallet_address=wallet_address_2, + user_wallet_address=WALLET_ADDRESS2, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() - timedelta(days=1), - created_at=datetime.now() + timedelta(hours=1), + created_at=datetime.now() - timedelta(hours=1), + expires_at=datetime.now(), + status=AssignmentStatuses.created, ) - self.session.add(assignment) - self.session.add(assignment_2) - self.session.commit() + self.session.add(assignment2) - db_assignments = sorted( - self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id - ) - assert db_assignments[0].status == AssignmentStatuses.created.value - assert db_assignments[1].status == AssignmentStatuses.created.value + self.session.commit() - with patch("src.crons.cvat.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: + with patch( + "src.crons.cvat.state_trackers.cvat_api.update_job_assignee" + ) as update_job_assignee: track_assignments() - mock_cvat_api.assert_called_once_with(assignment_2.cvat_job_id, assignee_id=None) - self.session.commit() + update_job_assignee.assert_called_once_with(assignment2.cvat_job_id, assignee_id=None) db_assignments = sorted( self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id ) - assert db_assignments[0].status == AssignmentStatuses.created.value - assert db_assignments[1].status == AssignmentStatuses.expired.value - - @pytest.mark.xfail( - strict=True, - reason=""" -Fix src.crons.cvat.state_trackers.py -Where in `cvat_service.get_active_assignments()` return value will be empty -because it actually looking for the expired assignments -""", - ) - def test_track_canceled_assignments(self): - (_, _, cvat_job) = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) - (cvat_project_2, _, cvat_job_2) = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC68", 2 + assert db_assignments[0].status == AssignmentStatuses.expired + assert db_assignments[1].status == AssignmentStatuses.expired + + assert ( + self.session.query(Job).filter(Job.id == cvat_job.id).first().status == JobStatuses.new ) - wallet_address_1 = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + + def test_can_track_canceled_assignments(self): + (_, _, cvat_job) = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + cvat_job.status = JobStatuses.in_progress + self.session.add(cvat_job) + user = User( - wallet_address=wallet_address_1, + wallet_address=WALLET_ADDRESS1, cvat_email="test@hmt.ai", cvat_id=1, ) self.session.add(user) - wallet_address_2 = "0x86e83d346041E8806e352681f3F14549C0d2BC68" user = User( - wallet_address=wallet_address_2, + wallet_address=WALLET_ADDRESS2, cvat_email="test2@hmt.ai", cvat_id=2, ) self.session.add(user) - assignment = Assignment( + + assignment1 = Assignment( id=str(uuid.uuid4()), - user_wallet_address=wallet_address_1, + user_wallet_address=WALLET_ADDRESS1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), - ) - assignment_2 = Assignment( - id=str(uuid.uuid4()), - user_wallet_address=wallet_address_2, - cvat_job_id=cvat_job_2.cvat_id, - expires_at=datetime.now() + timedelta(days=1), - created_at=datetime.now() + timedelta(hours=1), + created_at=datetime.now() - timedelta(hours=2), + expires_at=datetime.now() - timedelta(hours=1), + status=AssignmentStatuses.canceled, ) - self.session.add(assignment) - self.session.add(assignment_2) + self.session.add(assignment1) - self.session.execute( - update(Project) - .where(Project.id == cvat_project_2.id) - .values(status=ProjectStatuses.completed.value) + assignment2 = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=WALLET_ADDRESS2, + cvat_job_id=cvat_job.cvat_id, + created_at=datetime.now() - timedelta(hours=1), + expires_at=datetime.now() + timedelta(hours=1), + status=AssignmentStatuses.canceled, ) + self.session.add(assignment2) self.session.commit() - db_assignments = sorted( - self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id - ) - assert db_assignments[0].status == AssignmentStatuses.created.value - assert db_assignments[1].status == AssignmentStatuses.created.value - - with patch("src.crons.cvat.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: + with patch( + "src.crons.cvat.state_trackers.cvat_api.update_job_assignee" + ) as update_job_assignee: track_assignments() - mock_cvat_api.assert_called_once_with(assignment_2.cvat_job_id, assignee_id=None) - self.session.commit() + update_job_assignee.assert_called_once_with(assignment2.cvat_job_id, assignee_id=None) db_assignments = sorted( self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id ) - assert db_assignments[0].status == AssignmentStatuses.created.value - assert db_assignments[1].status == AssignmentStatuses.canceled.value + assert db_assignments[0].status == AssignmentStatuses.canceled + assert db_assignments[1].status == AssignmentStatuses.canceled + + assert ( + self.session.query(Job).filter(Job.id == cvat_job.id).first().status == JobStatuses.new + ) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py index 0863ae43d6..f5dd8a6ca1 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py @@ -1,4 +1,3 @@ -import unittest import uuid from datetime import datetime, timedelta @@ -17,7 +16,9 @@ ) from src.db import SessionLocal from src.models.cvat import Assignment, DataUpload, Image, Job, Project, Task, User +from src.utils.time import utcnow +from tests.utils.constants import WALLET_ADDRESS1, WALLET_ADDRESS2 from tests.utils.db_helper import ( create_project, create_project_and_task, @@ -25,12 +26,18 @@ ) -class ServiceIntegrationTest(unittest.TestCase): +class ServiceIntegrationTest: + @pytest.fixture(autouse=True) def setUp(self): self.session = SessionLocal() - def tearDown(self): - self.session.close() + try: + self.session.begin() + + yield + finally: + self.session.rollback() + self.session.close() def test_create_project(self): cvat_id = 1 @@ -346,6 +353,127 @@ def test_get_projects_by_status(self): assert len(projects) == 1 + def test_can_get_free_job_if_exists(self): + escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + + (cvat_project, cvat_task, cvat_job) = create_project_task_and_job( + self.session, escrow_address, cvat_id=1 + ) + chain_id = cvat_project.chain_id + + user = User(wallet_address=WALLET_ADDRESS1, cvat_email="test1@hmt.ai", cvat_id=1) + self.session.add(user) + + self.session.commit() + + free_job = cvat_service.get_free_job( + self.session, escrow_address, chain_id, user_wallet_address=WALLET_ADDRESS1 + ) + assert free_job.id == cvat_job.id + + def test_cannot_get_free_job_if_all_completed_and_not_project_checked_yet(self): + escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + + (cvat_project, cvat_task, cvat_job) = create_project_task_and_job( + self.session, escrow_address, cvat_id=1 + ) + chain_id = cvat_project.chain_id + + cvat_job.status = JobStatuses.completed.value + cvat_job.updated_at = utcnow() + self.session.add(cvat_job) + + user1 = User(wallet_address=WALLET_ADDRESS1, cvat_email="test1@hmt.ai", cvat_id=1) + self.session.add(user1) + + user2 = User(wallet_address=WALLET_ADDRESS2, cvat_email="test2@hmt.ai", cvat_id=2) + self.session.add(user2) + + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=WALLET_ADDRESS2, + cvat_job_id=cvat_job.cvat_id, + expires_at=utcnow() + timedelta(days=1), + completed_at=utcnow(), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + + self.session.commit() + + free_job = cvat_service.get_free_job( + self.session, escrow_address, chain_id, user_wallet_address=WALLET_ADDRESS1 + ) + assert free_job is None + + @pytest.mark.parametrize("previous_assignment_status", AssignmentStatuses) + def test_cannot_get_free_job_if_was_assigned_to_this_user( + self, previous_assignment_status: AssignmentStatuses + ): + escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + + (cvat_project, _, cvat_job) = create_project_task_and_job( + self.session, escrow_address, cvat_id=1 + ) + chain_id = cvat_project.chain_id + + user1 = User(wallet_address=WALLET_ADDRESS1, cvat_email="test1@hmt.ai", cvat_id=1) + self.session.add(user1) + + user2 = User(wallet_address=WALLET_ADDRESS2, cvat_email="test2@hmt.ai", cvat_id=2) + self.session.add(user2) + + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=WALLET_ADDRESS1, + cvat_job_id=cvat_job.cvat_id, + expires_at=utcnow() + timedelta(days=1), + status=previous_assignment_status.value, + ) + if previous_assignment_status == AssignmentStatuses.completed: + assignment.completed_at = utcnow() + self.session.add(assignment) + + self.session.commit() + + free_job = cvat_service.get_free_job( + self.session, escrow_address, chain_id, user_wallet_address=WALLET_ADDRESS1 + ) + assert free_job is None + + def test_cannot_get_free_job_if_assigned_to_other_user(self): + escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + + (cvat_project, _, cvat_job) = create_project_task_and_job( + self.session, escrow_address, cvat_id=1 + ) + chain_id = cvat_project.chain_id + + cvat_job.status = JobStatuses.in_progress + self.session.add(cvat_job) + + user1 = User(wallet_address=WALLET_ADDRESS1, cvat_email="test1@hmt.ai", cvat_id=1) + self.session.add(user1) + + user2 = User(wallet_address=WALLET_ADDRESS2, cvat_email="test2@hmt.ai", cvat_id=2) + self.session.add(user2) + + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=WALLET_ADDRESS2, + cvat_job_id=cvat_job.cvat_id, + expires_at=utcnow() + timedelta(days=1), + status=AssignmentStatuses.created.value, + ) + self.session.add(assignment) + + self.session.commit() + + free_job = cvat_service.get_free_job( + self.session, escrow_address, chain_id, user_wallet_address=WALLET_ADDRESS1 + ) + assert free_job is None + def test_update_project_status(self): cvat_id = 1 cvat_cloudstorage_id = 1 @@ -1215,7 +1343,7 @@ def test_create_assignment(self): session=self.session, wallet_address=wallet_address, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) assignment_count = self.session.query(Assignment).count() @@ -1237,7 +1365,7 @@ def test_create_assignment_invalid_address(self): session=self.session, wallet_address="invalid_address", cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) with pytest.raises(IntegrityError): self.session.commit() @@ -1255,7 +1383,7 @@ def test_create_assignment_invalid_address(self): session=self.session, wallet_address=wallet_address, cvat_job_id=0, - expires_at=datetime.now(), + expires_at=utcnow(), ) with pytest.raises(IntegrityError): self.session.commit() @@ -1283,13 +1411,13 @@ def test_get_assignments_by_id(self): session=self.session, wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) assignment_2 = cvat_service.create_assignment( session=self.session, wallet_address=wallet_address_2, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) self.session.commit() @@ -1329,14 +1457,14 @@ def test_get_latest_assignment_by_cvat_job_id(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), - created_at=datetime.now() - timedelta(days=1), + expires_at=utcnow(), + created_at=utcnow() - timedelta(days=1), ) assignment_2 = Assignment( id=str(uuid.uuid4()), user_wallet_address=wallet_address_2, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) self.session.add(assignment) self.session.add(assignment_2) @@ -1372,13 +1500,13 @@ def test_get_unprocessed_expired_assignments(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), + expires_at=utcnow() + timedelta(days=1), ) assignment_2 = Assignment( id=str(uuid.uuid4()), user_wallet_address=wallet_address_2, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() - timedelta(days=1), + expires_at=utcnow() - timedelta(days=1), ) self.session.add(assignment) self.session.add(assignment_2) @@ -1406,7 +1534,7 @@ def test_update_assignment(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), + expires_at=utcnow() + timedelta(days=1), ) self.session.add(assignment) self.session.commit() @@ -1436,7 +1564,7 @@ def test_cancel_assignment(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), + expires_at=utcnow() + timedelta(days=1), ) self.session.add(assignment) self.session.commit() @@ -1464,7 +1592,7 @@ def test_expire_assignment(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), + expires_at=utcnow() + timedelta(days=1), ) self.session.add(assignment) self.session.commit() @@ -1492,11 +1620,11 @@ def test_complete_assignment(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now() + timedelta(days=1), + expires_at=utcnow() + timedelta(days=1), ) self.session.add(assignment) self.session.commit() - completed_date = datetime.now() + timedelta(days=1) + completed_date = utcnow() + timedelta(days=1) cvat_service.complete_assignment(self.session, assignment.id, completed_date) db_assignment = self.session.query(Assignment).filter_by(id=assignment.id).first() @@ -1528,13 +1656,13 @@ def test_test_add_project_images(self): id=str(uuid.uuid4()), user_wallet_address=wallet_address_1, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) assignment_2 = Assignment( id=str(uuid.uuid4()), user_wallet_address=wallet_address_2, cvat_job_id=cvat_job.cvat_id, - expires_at=datetime.now(), + expires_at=utcnow(), ) self.session.add(assignment) self.session.add(assignment_2) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index 7d074d12e9..311cd8b9b2 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -76,7 +76,13 @@ def test_serialize_task_invalid_manifest(self): serialize_job(cvat_project) def test_create_assignment(self): - cvat_project_1, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + cvat_project, cvat_task, cvat_job = create_project_task_and_job( + self.session, ESCROW_ADDRESS, 1 + ) + initial_job_updated_at = cvat_job.updated_at + initial_task_updated_at = cvat_task.updated_at + initial_project_updated_at = cvat_project.updated_at + user_address = WALLET_ADDRESS1 user = User( wallet_address=user_address, @@ -84,18 +90,28 @@ def test_create_assignment(self): cvat_id=1, ) self.session.add(user) + self.session.commit() with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project_1.escrow_address, Networks(cvat_project_1.chain_id), user_address + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address ) - assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() + + assert assignment.cvat_job_id == cvat_job.cvat_id + assert assignment.user_wallet_address == user_address + assert assignment.status == AssignmentStatuses.created + + self.session.refresh(cvat_job) + assert cvat_job.updated_at != initial_job_updated_at + + self.session.refresh(cvat_task) + assert cvat_task.updated_at != initial_task_updated_at - assert assignment.cvat_job_id == cvat_job_1.cvat_id - assert assignment.user_wallet_address == user_address - assert assignment.status == AssignmentStatuses.created + self.session.refresh(cvat_project) + assert cvat_project.updated_at != initial_project_updated_at def test_create_assignment_many_jobs_1_completed(self): cvat_project, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) @@ -163,6 +179,10 @@ def test_create_assignment_invalid_project(self): def test_create_assignment_unfinished_assignment(self): _, _, cvat_job = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + + cvat_job.status = JobStatuses.in_progress + self.session.add(cvat_job) + user_address = WALLET_ADDRESS1 user = User( wallet_address=user_address, @@ -266,6 +286,9 @@ def test_create_assignment_no_available_jobs_completed_assignment(self): def test_create_assignment_no_available_jobs_active_foreign_assignment(self): cvat_project, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + cvat_job_1.status = JobStatuses.in_progress + self.session.add(cvat_job_1) + user_address1 = WALLET_ADDRESS1 user1 = User( wallet_address=user_address1, diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py index 8bc1a01515..77f657f98b 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py @@ -1,16 +1,21 @@ import hmac import json import uuid +from collections.abc import Generator, Sequence +from contextlib import ExitStack, contextmanager from datetime import datetime from hashlib import sha256 +from sqlalchemy.orm import Session from sqlalchemy.sql import select from src.core.config import CvatConfig -from src.core.types import ProjectStatuses, TaskTypes +from src.core.types import AssignmentStatuses, JobStatuses, ProjectStatuses, TaskStatuses, TaskTypes from src.db import SessionLocal from src.models.cvat import Assignment, Job, Project, Task, User +from tests.utils.constants import ESCROW_ADDRESS + def generate_cvat_signature(data: dict): b_data = json.dumps(data).encode("utf-8") @@ -25,89 +30,121 @@ def generate_cvat_signature(data: dict): ) -def add_cvat_project_to_db(cvat_id: int) -> str: - with SessionLocal.begin() as session: +def add_cvat_project_to_db(cvat_id: int, *, session: Session | None = None) -> Project: + with get_session(session) as session_: project_id = str(uuid.uuid4()) project = Project( id=project_id, cvat_id=cvat_id, cvat_cloudstorage_id=1, - status=ProjectStatuses.annotation.value, - job_type=TaskTypes.image_label_binary.value, - escrow_address="0x86e83d346041E8806e352681f3F14549C0d2BC67", + status=ProjectStatuses.annotation, + job_type=TaskTypes.image_label_binary, + escrow_address=ESCROW_ADDRESS, chain_id=80002, bucket_url="https://test.storage.googleapis.com/", ) - session.add(project) + session_.add(project) - return project_id + return project -def add_cvat_task_to_db(cvat_id: int, cvat_project_id: int, status: str) -> str: - with SessionLocal.begin() as session: +def add_cvat_task_to_db( + cvat_id: int, + cvat_project_id: int, + *, + status: TaskStatuses | str = TaskStatuses.annotation, + session: Session | None = None, +) -> Task: + with get_session(session) as session_: task_id = str(uuid.uuid4()) task = Task( id=task_id, cvat_id=cvat_id, cvat_project_id=cvat_project_id, - status=status, + status=TaskStatuses(status) if not isinstance(status, TaskStatuses) else status, ) - session.add(task) + session_.add(task) - return task_id + return task -# FUTURE-FIXME: a lot of ways to create a test job -def add_cvat_job_to_db(cvat_id: int, cvat_task_id: int, cvat_project_id: int, status: str) -> str: - with SessionLocal.begin() as session: +def add_cvat_job_to_db( + cvat_id: int, + cvat_task_id: int, + cvat_project_id: int, + *, + status: JobStatuses | str = JobStatuses.new, + session: Session | None = None, +) -> Job: + with get_session(session) as session_: job_id = str(uuid.uuid4()) job = Job( id=job_id, cvat_id=cvat_id, cvat_task_id=cvat_task_id, cvat_project_id=cvat_project_id, - status=status, + status=JobStatuses(status) if not isinstance(status, JobStatuses) else status, start_frame=0, stop_frame=1, ) - session.add(job) + session_.add(job) - return job_id + return job -def add_asignment_to_db( - wallet_address: str, cvat_id: int, cvat_job_id: int, expires_at: datetime -) -> str: - with SessionLocal.begin() as session: +def add_assignment_to_db( + wallet_address: str, + cvat_id: int, + cvat_job_id: int, + expires_at: datetime, + *, + status: AssignmentStatuses | str = AssignmentStatuses.created, + session: Session | None = None, +) -> Assignment: + with get_session(session) as session_: user = User( wallet_address=wallet_address, cvat_email="test" + str(cvat_id) + "@hmt.ai", cvat_id=cvat_id, ) - session.add(user) + session_.add(user) assignment_id = str(uuid.uuid4()) assignment = Assignment( id=assignment_id, user_wallet_address=wallet_address, cvat_job_id=cvat_job_id, expires_at=expires_at, + status=AssignmentStatuses(status) + if not isinstance(status, AssignmentStatuses) + else status, ) - session.add(assignment) + session_.add(assignment) - return assignment_id + return assignment -def get_cvat_job_from_db(cvat_id: int) -> tuple: - with SessionLocal.begin() as session: - session.expire_on_commit = False +def get_cvat_job_from_db( + cvat_id: int, *, session: Session | None = None +) -> tuple[Job, Sequence[Assignment]]: + with get_session(session) as session_: job_query = select(Job).where(Job.cvat_id == cvat_id) - job = session.execute(job_query).scalars().first() + job = session_.execute(job_query).scalars().first() + + assignments_query = select(Assignment).where(Assignment.cvat_job_id == cvat_id) + assignments = session_.execute(assignments_query).scalars().all() + + return job, assignments + - asignments_query = select(Assignment).where(Assignment.cvat_job_id == cvat_id) - asignments = session.execute(asignments_query).scalars().all() +@contextmanager +def get_session(session: Session | None = None) -> Generator[Session, None, None]: + with ExitStack() as es: + if not session: + session = es.enter_context(SessionLocal.begin()) + session.expire_on_commit = False - return job, asignments + yield session diff --git a/packages/examples/cvat/recording-oracle/Dockerfile b/packages/examples/cvat/recording-oracle/Dockerfile index 86c54f1441..69e547480d 100644 --- a/packages/examples/cvat/recording-oracle/Dockerfile +++ b/packages/examples/cvat/recording-oracle/Dockerfile @@ -4,13 +4,17 @@ WORKDIR /app RUN apt-get update -y && \ apt-get install -y jq ffmpeg libsm6 libxext6 && \ - pip install --no-cache poetry + rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache poetry COPY pyproject.toml poetry.lock ./ -RUN poetry config virtualenvs.create false \ - && poetry install --no-interaction --no-ansi --no-root \ - && poetry cache clear pypi --all +RUN --mount=type=cache,target=/root/.cache \ + poetry config virtualenvs.create false && \ + poetry install --no-interaction --no-ansi --no-root + +RUN python -m pip uninstall -y poetry pip COPY . . diff --git a/packages/examples/cvat/recording-oracle/README.MD b/packages/examples/cvat/recording-oracle/README.md similarity index 87% rename from packages/examples/cvat/recording-oracle/README.MD rename to packages/examples/cvat/recording-oracle/README.md index 55faf0c916..6b412f5081 100644 --- a/packages/examples/cvat/recording-oracle/README.MD +++ b/packages/examples/cvat/recording-oracle/README.md @@ -18,14 +18,14 @@ For deployment it is required to have PostgreSQL(v14.4) ### Run the oracle locally: -``` +```sh docker compose -f docker-compose.dev.yml up -d ./bin/start_dev.sh ``` or -``` +```sh docker compose -f docker-compose.dev.yml up -d ./bin/start_debug.sh ``` @@ -46,17 +46,17 @@ Config file: `/src/config.py` To simplify the process and use `--autogenerate` flag, you need to import a new model to `/alembic/env.py` Adding new migration: -``` +```sh alembic revision --autogenerate -m "your-migration-name" ``` Upgrade: -``` +```sh alembic upgrade head ``` Downgrade: -``` +```sh alembic downgrade -{number of migrations} ``` @@ -69,6 +69,7 @@ Available at `/docs` route ### Tests To run tests -``` -docker compose -f docker-compose.test.yml up --build test --attach test --exit-code-from test +```sh +docker compose -p "test" -f docker-compose.test.yml up --build test --attach test --exit-code-from test; \ + docker compose -p "test" -f docker-compose.test.yml down ``` diff --git a/packages/examples/cvat/recording-oracle/dockerfiles/test.Dockerfile b/packages/examples/cvat/recording-oracle/dockerfiles/test.Dockerfile index 591d0cb769..eaa436f5d0 100644 --- a/packages/examples/cvat/recording-oracle/dockerfiles/test.Dockerfile +++ b/packages/examples/cvat/recording-oracle/dockerfiles/test.Dockerfile @@ -4,16 +4,20 @@ WORKDIR /app RUN apt-get update -y && \ apt-get install -y jq ffmpeg libsm6 libxext6 && \ - pip install --no-cache poetry + rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache poetry COPY pyproject.toml poetry.lock ./ -RUN poetry config virtualenvs.create false \ - && poetry install --no-interaction --no-ansi --no-root \ - && poetry cache clear pypi --all +RUN --mount=type=cache,target=/root/.cache \ + poetry config virtualenvs.create false && \ + poetry install --no-interaction --no-ansi --no-root + +RUN python -m pip uninstall -y poetry pip COPY . . RUN rm -f ./src/.env -CMD ["pytest", "-W", "ignore::DeprecationWarning", "-W", "ignore::RuntimeWarning", "-W", "ignore::UserWarning", "-v"] \ No newline at end of file +CMD ["pytest"] \ No newline at end of file diff --git a/packages/examples/cvat/recording-oracle/pyproject.toml b/packages/examples/cvat/recording-oracle/pyproject.toml index 194543a267..e12f60a4ce 100644 --- a/packages/examples/cvat/recording-oracle/pyproject.toml +++ b/packages/examples/cvat/recording-oracle/pyproject.toml @@ -123,6 +123,7 @@ ignore = [ "ANN001", # | "ANN003", # | "ARG001", # | + "FBT001", # Allow bool-annotated positional args in functions "SLF001", # Allow private attrs access "PLR2004", # Allow magic values "S", # security diff --git a/packages/examples/cvat/recording-oracle/pytest.ini b/packages/examples/cvat/recording-oracle/pytest.ini new file mode 100644 index 0000000000..bdf7142e2b --- /dev/null +++ b/packages/examples/cvat/recording-oracle/pytest.ini @@ -0,0 +1,10 @@ +[pytest] +addopts = --verbose +filterwarnings = + ignore::DeprecationWarning:cvat_sdk.core + ignore::DeprecationWarning:human_protocol_sdk.storage + ignore:Field name \"sort\" shadows:UserWarning:pydantic._internal._fields + +python_files = test_*.py +python_classes = *Test +python_functions = test_* \ No newline at end of file diff --git a/packages/examples/cvat/recording-oracle/src/utils/requests.py b/packages/examples/cvat/recording-oracle/src/utils/requests.py index ef2174f9b9..73519f3947 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/requests.py +++ b/packages/examples/cvat/recording-oracle/src/utils/requests.py @@ -9,8 +9,8 @@ def get_or_404( obj: T | None, object_id: V, - object_type_name: str, *, + object_type_name: str, reason: str | None = None, ) -> T: if obj is None: diff --git a/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py b/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py index 1ba1e17052..e901f89c14 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py @@ -103,7 +103,7 @@ def test_create_and_get_validation_result(self): assert vrs[0] == vr -class TestManifestChange: +class ManifestChangeTest: def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Session): escrow_address = ESCROW_ADDRESS chain_id = Networks.localhost @@ -282,7 +282,7 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess ) -class TestValidationLogic: +class ValidationLogicTest: @pytest.mark.parametrize("seed", range(25)) def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): escrow_address = ESCROW_ADDRESS @@ -1134,7 +1134,7 @@ def patched_get_jobs_quality_reports(task_id: int): mock_update_task_validation_layout.assert_not_called() -class TestAnnotationMerging: +class AnnotationMergingTest: def test_can_prepare_final_results_in_validated_escrow(self, session: Session): escrow_address = ESCROW_ADDRESS chain_id = Networks.localhost.value