diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml index 05edd19d..480d29a8 100644 --- a/.github/workflows/build-and-push.yml +++ b/.github/workflows/build-and-push.yml @@ -1,9 +1,9 @@ name: Build and Push Docker Images -on: - push: - branches: - - main +on: push +# push: +# branches: +# - main env: REGISTRY: ghcr.io diff --git a/fia_api/core/exceptions.py b/fia_api/core/exceptions.py index 9dd95ebb..86b0e493 100644 --- a/fia_api/core/exceptions.py +++ b/fia_api/core/exceptions.py @@ -72,3 +72,7 @@ class DataIntegrityError(Exception): class JobOwnerError(DataIntegrityError): """Job has no owner""" + + +class LLSPSubmissionFailureError(Exception): + """LLSP submission failed""" diff --git a/fia_api/core/job_maker.py b/fia_api/core/job_maker.py index eebf9891..6b9862e4 100644 --- a/fia_api/core/job_maker.py +++ b/fia_api/core/job_maker.py @@ -3,16 +3,18 @@ import functools import json import logging +import os from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import Any, Concatenate, ParamSpec, TypeVar, cast +import requests from pika.adapters.blocking_connection import BlockingConnection # type: ignore[import-untyped] from pika.connection import ConnectionParameters # type: ignore[import-untyped] from pika.credentials import PlainCredentials # type: ignore[import-untyped] from sqlalchemy.orm import Session -from fia_api.core.exceptions import JobRequestError +from fia_api.core.exceptions import JobRequestError, LLSPSubmissionFailureError from fia_api.core.models import Job, JobOwner, JobType, Script, State from fia_api.core.repositories import Repo from fia_api.core.specifications.job import JobSpecification @@ -23,7 +25,11 @@ logger = logging.getLogger(__name__) -def require_owner(func: Callable[..., Any]) -> Callable[..., Any]: +P = ParamSpec("P") +R = TypeVar("R") + + +def require_owner(func: Callable[Concatenate[JobMaker, P], R]) -> Callable[Concatenate[JobMaker, P], R]: """ Decorator to ensure that either a user_number or experiment_number is provided to the function. if not, raise a JobRequestError @@ -32,12 +38,12 @@ def require_owner(func: Callable[..., Any]) -> Callable[..., Any]: """ @functools.wraps(func) - def wrapper(self: JobMaker, *args: tuple[Any], **kwargs: dict[str, Any]) -> Any: + def wrapper(self: JobMaker, *args: P.args, **kwargs: P.kwargs) -> R: if kwargs.get("user_number") is None and kwargs.get("experiment_number") is None: raise JobRequestError("Something needs to own the job, either experiment_number or user_number.") return func(self, *args, **kwargs) - return wrapper + return cast("Callable[Concatenate[JobMaker, P], R]", wrapper) class JobMaker: @@ -58,7 +64,6 @@ def __init__( self.queue_name = queue_name self.connection = None self.channel = None - self._connect_to_broker() def _connect_to_broker(self) -> None: """ @@ -195,6 +200,53 @@ def create_simple_job( self._send_message(json.dumps(message_dict)) return job.id + @require_owner + def create_fast_start_job( + self, runner_image: str, script: str, experiment_number: int | None = None, user_number: int | None = None + ) -> int: + """ + Create a fast start job by calling the external LLSP API. + :param runner_image: The image used as a runner + :param script: The script to be used + :param experiment_number: (unused but required by decorator/signature consistency if reused) + :param user_number: the user number of the owner + :return: created job id + """ + + job_owner = self._get_or_create_job_owner(None, user_number) # fast starts are user jobs only + + script_object = self._get_or_create_script(script) + + job = Job( + owner=job_owner, + job_type=JobType.FAST_START, + runner_image=runner_image, + script_id=script_object.id, + state=State.NOT_STARTED, + inputs={}, + ) + + # Call External API + llsp_url = os.environ.get("LLSP_API_HOST", "http://localhost:8001") + llsp_key = os.environ.get("LLSP_API_KEY", "shh") + + try: + response = requests.post( + f"{llsp_url}/execute", + json={"script": script}, + headers={"Authorization": f"Bearer {llsp_key}"}, + timeout=10, + ) + response.raise_for_status() + except requests.RequestException as e: + logger.error(f"Failed to submit fast start job to LLSP: {e}") + job.state = State.UNSUCCESSFUL + job.outputs = "Job failed to submit to LLSP." + raise LLSPSubmissionFailureError(f"Failed to submit fast start job: {e}") from e + finally: + job = self._job_repo.add_one(job) + return job.id + def _get_or_create_script(self, script: str) -> Script: script_hash = hash_script(script) script_object = self._script_repo.find_one(ScriptSpecification().by_script_hash(script_hash)) diff --git a/fia_api/core/models.py b/fia_api/core/models.py index 88fc6cc8..dde5cde5 100644 --- a/fia_api/core/models.py +++ b/fia_api/core/models.py @@ -33,6 +33,7 @@ class JobType(enum.Enum): RERUN = "RERUN" SIMPLE = "SIMPLE" AUTOREDUCTION = "AUTOREDUCTION" + FAST_START = "FAST_START" class Base(DeclarativeBase): diff --git a/fia_api/core/services/job.py b/fia_api/core/services/job.py index 161870f1..250f7f00 100644 --- a/fia_api/core/services/job.py +++ b/fia_api/core/services/job.py @@ -17,6 +17,7 @@ from fia_api.core.repositories import Repo from fia_api.core.request_models import AutoreductionRequest, PartialJobUpdateRequest from fia_api.core.session import get_db_session +from fia_api.core.specifications.base import Specification from fia_api.core.specifications.filters import apply_filters_to_spec from fia_api.core.specifications.instrument import InstrumentSpecification from fia_api.core.specifications.job import JobSpecification @@ -53,6 +54,10 @@ class SimpleJob(BaseModel): script: str +class FastStartJob(BaseModel): + script: str + + class RerunJob(BaseModel): job_id: int runner_image: str @@ -82,6 +87,7 @@ def get_job_by_instrument( order_direction: Literal["asc", "desc"] = "desc", user_number: int | None = None, filters: Mapping[str, Any] | None = None, + include_fast_start_jobs: bool = False, ) -> Sequence[Job]: """ Given an instrument name return a sequence of jobs for that instrument. Optionally providing a limit and @@ -94,6 +100,7 @@ def get_job_by_instrument( :param order_by: (str) Field to order by. :param user_number: (optional[str]) The user number of who is making the request :param filters: Optional Mapping[str,Any] the filters to be applied to the query + :param include_fast_start_jobs: (bool) Whether to include fast start jobs :return: Sequence of Jobs for an instrument """ specification = JobSpecification().by_instruments( @@ -104,10 +111,13 @@ def get_job_by_instrument( order_direction=order_direction, user_number=user_number, ) + spec: Specification[Job] = specification if filters: - specification = apply_filters_to_spec(filters, specification) + spec = apply_filters_to_spec(filters, spec) + if not include_fast_start_jobs: + spec = apply_filters_to_spec({"job_type_not_in": [JobType.FAST_START]}, spec) job_repo: Repo[Job] = Repo(session) - return job_repo.find(specification) + return job_repo.find(spec) def get_all_jobs( @@ -118,6 +128,7 @@ def get_all_jobs( order_direction: Literal["asc", "desc"] = "desc", user_number: int | None = None, filters: Mapping[str, Any] | None = None, + include_fast_start_jobs: bool = False, ) -> Sequence[Job]: """ Get all jobs, if a user number is provided then only the jobs that user has permission for will be @@ -129,6 +140,7 @@ def get_all_jobs( :param order_direction: (str) Direction to der by "asc" | "desc" :param order_by: (str) Field to order by. :param filters: Optional Mapping[str,Any] the filters to be applied + :param include_fast_start_jobs: (bool) Whether to include fast start jobs :return: A Sequence of Jobs """ specification = JobSpecification() @@ -141,10 +153,13 @@ def get_all_jobs( specification = specification.by_experiment_numbers( experiment_numbers, limit=limit, offset=offset, order_by=order_by, order_direction=order_direction ) + spec: Specification[Job] = specification if filters: - apply_filters_to_spec(filters, specification) + spec = apply_filters_to_spec(filters, spec) + if not include_fast_start_jobs: + spec = apply_filters_to_spec({"job_type_not_in": [JobType.FAST_START]}, spec) job_repo: Repo[Job] = Repo(session) - return job_repo.find(specification) + return job_repo.find(spec) def get_job_by_id( @@ -174,16 +189,24 @@ def get_job_by_id( return job -def count_jobs_by_instrument(instrument: str, session: Session, filters: Mapping[str, Any] | None) -> int: +def count_jobs_by_instrument( + instrument: str, + session: Session, + filters: Mapping[str, Any] | None, + include_fast_start_jobs: bool = False, +) -> int: """ Given an instrument name, count the jobs for that instrument :param instrument: Instruments to count from :param session: The current session of the request + :param include_fast_start_jobs: (bool) Whether to include fast start jobs :return: Number of jobs """ - spec = JobSpecification().by_instruments(instruments=[instrument]) + spec: Specification[Job] = JobSpecification().by_instruments(instruments=[instrument]) if filters: spec = apply_filters_to_spec(filters, spec) + if not include_fast_start_jobs: + spec = apply_filters_to_spec({"job_type_not_in": [JobType.FAST_START]}, spec) job_repo: Repo[Job] = Repo(session) return job_repo.count(spec) @@ -191,16 +214,20 @@ def count_jobs_by_instrument(instrument: str, session: Session, filters: Mapping def count_jobs( session: Session, filters: Mapping[str, Any] | None = None, + include_fast_start_jobs: bool = False, ) -> int: """ Count the total number of jobs :param filters: Optional Mapping[str,Any] the filters to be applied :param session: The current session of the request + :param include_fast_start_jobs: (bool) Whether to include fast start jobs :return: (int) number of jobs """ - spec = JobSpecification().all() + spec: Specification[Job] = JobSpecification().all() if filters: spec = apply_filters_to_spec(filters, spec) + if not include_fast_start_jobs: + spec = apply_filters_to_spec({"job_type_not_in": [JobType.FAST_START]}, spec) job_repo: Repo[Job] = Repo(session) return job_repo.count(spec) diff --git a/fia_api/core/specifications/filters.py b/fia_api/core/specifications/filters.py index 1ab584b8..893e559a 100644 --- a/fia_api/core/specifications/filters.py +++ b/fia_api/core/specifications/filters.py @@ -11,7 +11,7 @@ from typing import Any from fia_api.core.exceptions import BadRequestError -from fia_api.core.models import Instrument, Job, JobOwner, Run +from fia_api.core.models import Instrument, Job, JobOwner, JobType, Run from fia_api.core.specifications.base import Specification, T logger = logging.getLogger(__name__) @@ -50,6 +50,24 @@ def apply(self, specification: Specification[T]) -> Specification[T]: return specification +class JobTypeInFilter(Filter): + """Filter implementation that checks if job types are included in the query.""" + + def apply(self, specification: Specification[T]) -> Specification[T]: + job_types = [JobType(val) for val in self.value] + specification.value = specification.value.where(Job.job_type.in_(job_types)) + return specification + + +class JobTypeNotInFilter(Filter): + """Filter implementation that checks if job types are NOT included in the query.""" + + def apply(self, specification: Specification[T]) -> Specification[T]: + job_types = [JobType(val) for val in self.value] + specification.value = specification.value.where(Job.job_type.notin_(job_types)) + return specification + + class JobStateFilter(Filter): """Filter implementation that checks if job states match the specified value in the query.""" @@ -170,6 +188,10 @@ def get_filter(key: str, value: Any) -> Filter: # noqa: C901, PLR0911, PLR0912 return JobStateFilter(value) case "experiment_number_in": return ExperimentNumberInFilter(value) + case "job_type_in": + return JobTypeInFilter(value) + case "job_type_not_in": + return JobTypeNotInFilter(value) case "title": return TitleFilter(value) case "filename": diff --git a/fia_api/exception_handlers.py b/fia_api/exception_handlers.py index 0392ddac..3c430d61 100644 --- a/fia_api/exception_handlers.py +++ b/fia_api/exception_handlers.py @@ -37,6 +37,19 @@ async def bad_job_request_handler(_: Request, __: Exception) -> JSONResponse: ) +async def llsp_api_request_handler(_: Request, __: Exception) -> JSONResponse: + """ + Automatically return a 424 when a job submission to LLSP-API fails + :param _: + :param __: + :return: JSONResponse with 424 + """ + return JSONResponse( + status_code=HTTPStatus.FAILED_DEPENDENCY, + content={"message": "Job Failed to submit to LLSP-API"}, + ) + + async def missing_script_handler(_: Request, __: Exception) -> JSONResponse: """ Automatically return a 404 when the script could not be found locally or remote diff --git a/fia_api/fia_api.py b/fia_api/fia_api.py index edbb6588..0dc0cbdf 100644 --- a/fia_api/fia_api.py +++ b/fia_api/fia_api.py @@ -15,6 +15,7 @@ GithubAPIRequestError, JobOwnerError, JobRequestError, + LLSPSubmissionFailureError, MissingRecordError, MissingScriptError, NoFilesAddedError, @@ -29,6 +30,7 @@ data_integrity_handler, github_api_request_handler, job_owner_err_handler, + llsp_api_request_handler, missing_record_handler, missing_script_handler, no_files_added_handler, @@ -98,3 +100,4 @@ def filter(self, record: logging.LogRecord) -> bool: app.add_exception_handler(BadRequestError, bad_request_handler) app.add_exception_handler(DataIntegrityError, data_integrity_handler) app.add_exception_handler(JobOwnerError, job_owner_err_handler) +app.add_exception_handler(LLSPSubmissionFailureError, llsp_api_request_handler) diff --git a/fia_api/routers/job_creation.py b/fia_api/routers/job_creation.py index d2d127c1..64d37d1a 100644 --- a/fia_api/routers/job_creation.py +++ b/fia_api/routers/job_creation.py @@ -9,6 +9,7 @@ from fia_api.core.exceptions import AuthError from fia_api.core.job_maker import JobMaker from fia_api.core.services.job import ( + FastStartJob, RerunJob, SimpleJob, get_experiment_number_for_job_id, @@ -44,7 +45,7 @@ async def make_rerun_job( if experiment_number not in experiment_numbers: # If not staff this is not allowed raise AuthError("User not authorised for this action") - return job_maker.create_rerun_job( # type: ignore # Despite returning int, mypy believes this returns any + return job_maker.create_rerun_job( job_id=rerun_job.job_id, runner_image=rerun_job.runner_image, script=rerun_job.script, @@ -70,11 +71,32 @@ async def make_simple_job( if user.role != "staff": # If not staff this is not allowed raise AuthError("User not authorised for this action") - return job_maker.create_simple_job( # type: ignore # Despite returning int, mypy believes this returns any + return job_maker.create_simple_job( runner_image=simple_job.runner_image, script=simple_job.script, user_number=user.user_number ) +@JobCreationRouter.post("/job/fast-start", tags=["job creation"]) +async def create_fast_start_job_endpoint( + fast_start_job: FastStartJob, + credentials: Annotated[HTTPAuthorizationCredentials, Depends(jwt_api_security)], + job_maker: Annotated[JobMaker, Depends(job_maker)], +) -> int: + """ + Create a fast start job, returning the ID of the created job. + \f + :param fast_start_job: The fast start job details including runner image and script. + :param credentials: HTTPAuthorizationCredentials + :param job_maker: Dependency injected job maker + :return: The job id + """ + user = get_user_from_token(credentials.credentials) + # Any authenticated user can create a fast start job + return job_maker.create_fast_start_job( + runner_image="default", script=fast_start_job.script, user_number=user.user_number + ) + + @JobCreationRouter.get("/jobs/runners", tags=["job creation"]) async def get_mantid_runners( credentials: Annotated[HTTPAuthorizationCredentials, Depends(jwt_api_security)], diff --git a/fia_api/routers/jobs.py b/fia_api/routers/jobs.py index c8dfff0a..ff8d1481 100644 --- a/fia_api/routers/jobs.py +++ b/fia_api/routers/jobs.py @@ -18,6 +18,7 @@ AuthError, NoFilesAddedError, ) +from fia_api.core.models import JobType from fia_api.core.request_models import AutoreductionRequest, PartialJobUpdateRequest from fia_api.core.responses import AutoreductionResponse, CountResponse, JobResponse, JobWithRunResponse from fia_api.core.services.job import ( @@ -68,6 +69,7 @@ async def get_jobs( include_run: bool = False, filters: Annotated[str | None, Query(description="json string of filters")] = None, as_user: bool = False, + include_fast_start_jobs: bool = False, ) -> list[JobResponse] | list[JobWithRunResponse]: """ Retrieve all jobs. @@ -83,6 +85,7 @@ async def get_jobs( :param include_run: bool :param filters: json string of filters :param as_user: bool + :param include_fast_start_jobs: bool :return: List of JobResponse objects """ user = get_user_from_token(credentials.credentials) @@ -121,6 +124,7 @@ async def get_jobs( order_direction=order_direction, user_number=user_number, filters=filters, + include_fast_start_jobs=include_fast_start_jobs, ) if include_run: @@ -135,7 +139,7 @@ async def get_jobs( @JobsRouter.get("/instrument/{instrument}/jobs", tags=["jobs"]) -async def get_jobs_by_instrument( +async def get_jobs_by_instrument( # noqa: PLR0913 instrument: str, credentials: Annotated[HTTPAuthorizationCredentials, Depends(jwt_api_security)], session: Annotated[Session, Depends(get_db_session)], @@ -146,6 +150,7 @@ async def get_jobs_by_instrument( include_run: bool = False, filters: Annotated[str | None, Query(description="json string of filters")] = None, as_user: bool = False, + include_fast_start_jobs: bool = False, ) -> list[JobResponse] | list[JobWithRunResponse]: """ Retrieve a list of jobs for a given instrument. @@ -162,6 +167,7 @@ async def get_jobs_by_instrument( :param include_run: bool :param filters: json string of filters :param as_user: bool + :param include_fast_start_jobs: bool :return: List of JobResponse objects """ user = get_user_from_token(credentials.credentials) @@ -204,6 +210,7 @@ async def get_jobs_by_instrument( order_direction=order_direction, user_number=user_number, filters=filters, + include_fast_start_jobs=include_fast_start_jobs, ) if include_run: @@ -222,12 +229,14 @@ async def count_jobs_for_instrument( instrument: str, session: Annotated[Session, Depends(get_db_session)], filters: Annotated[str | None, Query(description="json string of filters")] = None, + include_fast_start_jobs: bool = False, ) -> CountResponse: """ Count jobs for a given instrument. \f :param instrument: the name of the instrument :param filters: json string of filters + :param include_fast_start_jobs: bool :return: CountResponse containing the count """ instrument = instrument.upper() @@ -240,7 +249,9 @@ async def count_jobs_for_instrument( if isinstance(cached, dict) and "count" in cached: return CountResponse.model_validate(cached) - count = count_jobs_by_instrument(instrument, session, filters=parsed_filters) + count = count_jobs_by_instrument( + instrument, session, filters=parsed_filters, include_fast_start_jobs=include_fast_start_jobs + ) payload = {"count": count} if cache_key: cache_set_json(cache_key, payload, JOB_COUNT_CACHE_TTL_SECONDS) @@ -287,8 +298,22 @@ async def update_job( :return: JobResponse """ user = get_user_from_token(credentials.credentials) - if user.role != "staff": + + # Check permissions + is_staff = user.role == "staff" + + # We need to get the job to check its type for API key access + job_in_db = get_job_by_id(job_id, session) + + if is_staff: + # Staff can update any job + pass + elif user.user_number == -1 and job_in_db.job_type == JobType.FAST_START: + # API Key users can update FAST_START jobs + pass + else: raise AuthError("User not authorised for this action") + return JobResponse.from_job(update_job_by_id(job_id, job, session)) @@ -296,8 +321,11 @@ async def update_job( async def count_all_jobs( session: Annotated[Session, Depends(get_db_session)], filters: Annotated[str | None, Query(description="json string of filters")] = None, + include_fast_start_jobs: bool = False, ) -> CountResponse: - """Count all jobs \f :param filters: json string of filters :return: + """Count all jobs + \f + :param filters: json string of filters :return: CountResponse containing the count.""" parsed_filters = json.loads(filters) if filters else None @@ -308,7 +336,7 @@ async def count_all_jobs( if isinstance(cached, dict) and "count" in cached: return CountResponse.model_validate(cached) - count = count_jobs(session, filters=parsed_filters) + count = count_jobs(session, filters=parsed_filters, include_fast_start_jobs=include_fast_start_jobs) payload = {"count": count} if cache_key: cache_set_json(cache_key, payload, JOB_COUNT_CACHE_TTL_SECONDS) diff --git a/test/core/test_job_maker.py b/test/core/test_job_maker.py index e0d53668..a175ffbf 100644 --- a/test/core/test_job_maker.py +++ b/test/core/test_job_maker.py @@ -3,6 +3,7 @@ from unittest import mock import pytest # type: ignore +from requests.exceptions import RequestException from fia_api.core.exceptions import JobRequestError from fia_api.core.job_maker import JobMaker @@ -17,7 +18,7 @@ def test_send_message(broker): job_maker._send_message(custom_message) - assert broker.call_count == 2 # noqa: PLR2004 + assert broker.call_count == 1 assert broker.call_args == [mock.call(), mock.call()] assert job_maker.channel.basic_publish.call_count == 1 assert job_maker.channel.basic_publish.call_args == mock.call(exchange="", routing_key="", body=custom_message) @@ -184,3 +185,50 @@ def test_create_simple_job_require_owner(): script = "print('error')" with pytest.raises(JobRequestError): job_maker.create_simple_job(runner_image=runner_image, script=script, user_number=None, experiment_number=None) + + +@mock.patch("fia_api.core.job_maker.JobMaker._connect_to_broker") +@mock.patch("fia_api.core.job_maker.requests.post") +def test_create_fast_start_job_success(mock_post, mock_connect, faker): + mock_session = mock.Mock() + job_maker = JobMaker("", "", "", "test_queue", mock_session) + job_maker._owner_repo.find_one = mock.MagicMock(return_value=None) + job_maker._script_repo.find_one = mock.MagicMock(return_value=None) + job = mock.MagicMock() + job.id = faker.random.randint(1000, 2000) + job_maker._job_repo.add_one = mock.MagicMock(return_value=job) + + runner_image = "fast_runner" + script = "print('fast')" + user_number = 12345 + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + job_id = job_maker.create_fast_start_job(runner_image=runner_image, script=script, user_number=user_number) + + assert job_id == job.id + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + assert kwargs["json"] == {"script": script} + # Check default header key 'shh' + assert kwargs["headers"]["Authorization"] == "Bearer shh" + + +@mock.patch("fia_api.core.job_maker.JobMaker._connect_to_broker") +@mock.patch("fia_api.core.job_maker.requests.post") +def test_create_fast_start_job_failure(mock_post, mock_connect, faker): + mock_session = mock.Mock() + job_maker = JobMaker("", "", "", "test_queue", mock_session) + # job setup ... + job = mock.MagicMock() + job.id = 1 + job_maker._job_repo.add_one = mock.MagicMock(return_value=job) + job_maker._owner_repo.find_one = mock.MagicMock(return_value=None) + job_maker._script_repo.find_one = mock.MagicMock(return_value=None) + + mock_post.side_effect = RequestException("Connection error") + + with pytest.raises(JobRequestError, match="Failed to submit fast start job"): + job_maker.create_fast_start_job(runner_image="img", script="print('fail')", user_number=123) diff --git a/test/e2e/test_core.py b/test/e2e/test_core.py index 55858308..640e713b 100644 --- a/test/e2e/test_core.py +++ b/test/e2e/test_core.py @@ -7,7 +7,7 @@ from sqlalchemy import func, select from starlette.testclient import TestClient -from fia_api.core.models import Instrument, Job, JobOwner, Run +from fia_api.core.models import Instrument, Job, JobOwner, JobType, Run from fia_api.fia_api import app from utils.db_generator import SESSION @@ -60,7 +60,13 @@ def test_count_jobs_with_filters(mock_post): """Test count with filter""" expected_count = 0 with SESSION() as session: - expected_count = session.scalar(select(func.count()).select_from(Job).join(Run).where(Run.title.icontains("n"))) + expected_count = session.scalar( + select(func.count()) + .select_from(Job) + .join(Run) + .where(Run.title.icontains("n")) + .where(Job.job_type != JobType.FAST_START) + ) mock_post.return_value.status_code = HTTPStatus.OK response = client.get('/jobs/count?filters={"title":"n"}') assert response.json()["count"] == expected_count @@ -78,6 +84,7 @@ def test_count_jobs_by_instrument_with_filter(mock_post): .join(Instrument) .where(Run.title.icontains("n")) .where(Instrument.instrument_name == "MARI") + .where(Job.job_type != JobType.FAST_START) ) mock_post.return_value.status_code = HTTPStatus.OK @@ -444,9 +451,11 @@ def test_jobs_count(): Test count endpoint for all jobs :return: """ + with SESSION() as session: + expected_count = session.query(Job).where(Job.job_type != JobType.FAST_START).count() response = client.get("/jobs/count") assert response.status_code == HTTPStatus.OK - assert response.json()["count"] == 5001 # noqa: PLR2004 + assert response.json()["count"] == expected_count @patch("fia_api.core.auth.tokens.requests.post") diff --git a/test/utils.py b/test/utils.py index df53dc4e..71d14366 100644 --- a/test/utils.py +++ b/test/utils.py @@ -134,9 +134,13 @@ def job(self, instrument: Instrument, faker: Faker) -> Job: ) job.state = state job.stacktrace = "some stacktrace" - job.owner = JobOwner(experiment_number=faker.unique.pyint(min_value=10000, max_value=999999)) + job_type = faker.enum(JobType) + if job_type == JobType.FAST_START: + job.owner = JobOwner(user_number=faker.unique.pyint(min_value=10000, max_value=999999)) + else: + job.owner = JobOwner(experiment_number=faker.unique.pyint(min_value=10000, max_value=999999)) job.instrument = instrument - job.job_type = faker.enum(JobType) + job.job_type = job_type return job def script(self, faker: Faker) -> Script: