diff --git a/desdeo/api/models/__init__.py b/desdeo/api/models/__init__.py index f7b1fda1c..ce70820d0 100644 --- a/desdeo/api/models/__init__.py +++ b/desdeo/api/models/__init__.py @@ -23,7 +23,6 @@ "ExtraFunctionDB", "ForestProblemMetaData", "GenericIntermediateSolutionResponse", - "GetSessionRequest", "GNIMBUSOptimizationState", "GNIMBUSVotingState", "GNIMBUSEndState", @@ -56,7 +55,6 @@ "PreferredRanges", "ProblemDB", "ProblemAddFromJSONRequest", - "ProblemGetRequest", "ProblemInfo", "ProblemInfoSmall", "ProblemMetaDataDB", @@ -237,7 +235,6 @@ ObjectiveDB, ProblemAddFromJSONRequest, ProblemDB, - ProblemGetRequest, ProblemInfo, ProblemInfoSmall, ProblemMetaDataDB, @@ -254,7 +251,6 @@ from .reference_point_method import RPMSolveRequest from .session import ( CreateSessionRequest, - GetSessionRequest, InteractiveSessionBase, InteractiveSessionDB, InteractiveSessionInfo, diff --git a/desdeo/api/models/problem.py b/desdeo/api/models/problem.py index 4ba73dda7..fefd6c1a1 100644 --- a/desdeo/api/models/problem.py +++ b/desdeo/api/models/problem.py @@ -52,12 +52,6 @@ class ProblemBase(SQLModel): variable_domain: VariableDomainTypeEnum | None = Field() -class ProblemGetRequest(SQLModel): - """Model to deal with problem fetching requests.""" - - problem_id: int - - class ProblemSelectSolverRequest(SQLModel): """Model to request a specific solver for a problem.""" diff --git a/desdeo/api/models/representative_solution.py b/desdeo/api/models/representative_solution.py index 8b643768c..93077c774 100644 --- a/desdeo/api/models/representative_solution.py +++ b/desdeo/api/models/representative_solution.py @@ -10,11 +10,6 @@ class RepresentativeSolutionSetBase(SQLModel): ideal: dict[str, float] nadir: dict[str, float] -class RepresentativeSolutionSetRequest(RepresentativeSolutionSetBase): - """Model of the request to the representative solution set.""" - - problem_id: int - class RepresentativeSolutionSetInfo(SQLModel): """Model of the representative solution set info.""" @@ -25,8 +20,6 @@ class RepresentativeSolutionSetInfo(SQLModel): ideal: dict[str, float] nadir: dict[str, float] -class RepresentativeSolutionSetFull( - RepresentativeSolutionSetInfo -): +class RepresentativeSolutionSetFull(RepresentativeSolutionSetInfo): """Model of the representative solution set full info.""" solution_data: dict[str, list[float]] diff --git a/desdeo/api/models/session.py b/desdeo/api/models/session.py index 5b0907ec2..f8386baad 100644 --- a/desdeo/api/models/session.py +++ b/desdeo/api/models/session.py @@ -15,12 +15,6 @@ class CreateSessionRequest(SQLModel): info: str | None = Field(default=None) -class GetSessionRequest(SQLModel): - """Model of the request to get a specific session.""" - - session_id: int = Field() - - class InteractiveSessionBase(SQLModel): """The base model for representing interactive sessions.""" diff --git a/desdeo/api/routers/problem.py b/desdeo/api/routers/problem.py index 3b8a297b5..94bbfb782 100644 --- a/desdeo/api/routers/problem.py +++ b/desdeo/api/routers/problem.py @@ -10,7 +10,6 @@ from desdeo.api.models import ( ForestProblemMetaData, ProblemDB, - ProblemGetRequest, ProblemInfo, ProblemInfoSmall, ProblemMetaDataDB, @@ -22,9 +21,9 @@ UserRole, ) from desdeo.api.models.representative_solution import ( + RepresentativeSolutionSetBase, RepresentativeSolutionSetFull, RepresentativeSolutionSetInfo, - RepresentativeSolutionSetRequest, ) from desdeo.api.routers.user_authentication import get_current_user from desdeo.problem import Problem @@ -89,27 +88,28 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ """ return user.problems - -@router.post("/get") +@router.get("/{problem_id}") def get_problem( - request: ProblemGetRequest, - context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))], + problem_id: int, + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])), + ], ) -> ProblemInfo: - """Get the model of a specific problem. + """Get a specific problem by id. Args: - request (ProblemGetRequest): the request containing the problem's id `problem_id`. - context (Annotated[SessionContext, Depends): the session context. + problem_id (int): problem id. + context (Annotated[SessionContext, Depends): the session context. Raises: - HTTPException: could not find a problem with the given id. + HTTPException: could not find a problem with the given id. Returns: - ProblemInfo: detailed information on the requested problem. + ProblemInfo: detailed information on the requested problem. """ return context.problem_db - @router.post("/add") def add_problem( request: Annotated[Problem, Depends(parse_problem_json)], @@ -299,15 +299,17 @@ def select_solver( return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK) -@router.post("/add_representative_solution_set") +@router.post("/{problem_id}/add_representative_solution_set") def add_representative_solution_set( - request: RepresentativeSolutionSetRequest, + problem_id: int, + request: RepresentativeSolutionSetBase, context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))], ): """Add a new representative solution set as metadata to a problem. Args: - request (RepresentativeSolutionSetRequest): The JSON body containing the + problem_id: int, + request (RepresentativeSolutionSetBase): The JSON body containing the details of the representative solution set (name, description, solution data, ideal, nadir). context (SessionContext): The session context providing the current user and database session. @@ -349,8 +351,7 @@ def add_representative_solution_set( nadir=repr_metadata.nadir, ) - -@router.get("/all_representative_solution_sets/{problem_id}") +@router.get("/{problem_id}/all_representative_solution_sets") def get_all_representative_solution_sets( problem_id: int, context: Annotated[SessionContext, Depends(SessionContextGuard(require=[]))], diff --git a/desdeo/api/routers/session.py b/desdeo/api/routers/session.py index a5addcd81..d36c6e3ff 100644 --- a/desdeo/api/routers/session.py +++ b/desdeo/api/routers/session.py @@ -8,7 +8,6 @@ from desdeo.api.db import get_session as get_db_session from desdeo.api.models import ( CreateSessionRequest, - GetSessionRequest, InteractiveSessionDB, InteractiveSessionInfo, User, @@ -51,10 +50,9 @@ def get_session( session: Annotated[Session, Depends(get_db_session)], ) -> InteractiveSessionInfo: """Return an interactive session with a current user.""" - request = GetSessionRequest(session_id=session_id) return fetch_interactive_session( user=user, - request=request, + session_id=session_id, session=session, ) @@ -84,11 +82,9 @@ def delete_session( session: Annotated[Session, Depends(get_db_session)], ) -> None: """Delete an interactive session and all its related states.""" - request = GetSessionRequest(session_id=session_id) - interactive_session = fetch_interactive_session( user=user, - request=request, + session_id=session_id, session=session, ) # raises 404 if not found diff --git a/desdeo/api/routers/utils.py b/desdeo/api/routers/utils.py index 86b52a65e..557042d75 100644 --- a/desdeo/api/routers/utils.py +++ b/desdeo/api/routers/utils.py @@ -20,14 +20,18 @@ StateDB, User, ) -from desdeo.api.models.representative_solution import RepresentativeSolutionSetRequest from desdeo.api.models.session import CreateSessionRequest from desdeo.api.routers.user_authentication import get_current_user -RequestType = RPMSolveRequest | ENautilusStepRequest | RepresentativeSolutionSetRequest | CreateSessionRequest +RequestType = RPMSolveRequest | ENautilusStepRequest | CreateSessionRequest -def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None: +def fetch_interactive_session( + user: User, + session: Session, + request: RequestType | None = None, + session_id: int | None = None, + ) -> InteractiveSessionDB | None: """Gets the desired instance of `InteractiveSessionDB`. Args: @@ -35,6 +39,7 @@ def fetch_interactive_session(user: User, request: RequestType, session: Session request (RequestType): the request with possibly information on which interactive session to query. session (Session): the database session (not to be confused with the interactive session) from which the interactive session should be queried. + session_id (int): the id of a session Note: If no explicit `session_id` is given in `request`, this function will try to fetch the @@ -47,23 +52,28 @@ def fetch_interactive_session(user: User, request: RequestType, session: Session Returns: InteractiveSessionDB | None: an interactive session DB model, or nothing. """ - if request.session_id is not None: + # session_id param has highest priority + actual_session_id = session_id or (getattr(request, "session_id", None) if request else None) + + + if actual_session_id is not None: # specific interactive session id is given, try using that - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) + statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == actual_session_id) interactive_session = session.exec(statement).first() if interactive_session is None: # Raise if explicitly requested interactive session cannot be found raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", + detail=f"Could not find interactive session with id={actual_session_id}.", ) else: - # request.session_id is None + if user.active_session_id is None: + return None + # actual_session_id is None # try to use active session instead statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - interactive_session = session.exec(statement).first() # At this point interactive_session is either an instance of InteractiveSessionDB or None (which is fine) @@ -184,6 +194,7 @@ def __call__( user: Annotated[User, Depends(get_current_user)], db_session: Annotated[Session, Depends(get_session)], request: RequestType | None = None, + problem_id: int | None = None, ) -> SessionContext: """Call method for the SessionContextGuard class. @@ -192,6 +203,7 @@ def __call__( db_session (Annotated[Session, Depends): the current database session (dep). request (RequestType | None, optional): request based on which the context is fetched. Defaults to None. + problem_id (int): ID of the problem. Returns: SessionContext: the session context with the required fields specified in `self.require`. @@ -205,6 +217,14 @@ def __call__( if hasattr(request, "problem_id"): problem_db = fetch_user_problem(user, request, db_session) + if problem_db is None and problem_id is not None: + class _ProblemOnly: + def __init__(self, problem_id: int): + self.problem_id = problem_id + self.session_id = None + self.parent_state_id = None + problem_db = fetch_user_problem(user, _ProblemOnly(problem_id), db_session) + if hasattr(request, "interactive_session_id") or hasattr(request, "problem_id"): interactive_session = fetch_interactive_session(user, request, db_session) @@ -215,6 +235,17 @@ def __call__( db_session, interactive_session=interactive_session, ) + elif problem_id is not None: + # Build a minimal fake request-like object + class _ProblemOnly: + def __init__(self, problem_id: int): + self.problem_id = problem_id + self.session_id = None + self.parent_state_id = None + + pseudo_request = _ProblemOnly(problem_id) + + problem_db = fetch_user_problem(user, pseudo_request, db_session) context = SessionContext( user=user, diff --git a/desdeo/api/tests/test_problem_metadata.py b/desdeo/api/tests/test_problem_metadata.py index 3d623e250..f53def157 100644 --- a/desdeo/api/tests/test_problem_metadata.py +++ b/desdeo/api/tests/test_problem_metadata.py @@ -1,8 +1,11 @@ +from types import SimpleNamespace + from fastapi.testclient import TestClient # noqa: D100 from sqlmodel import select from desdeo.api.models import ProblemDB, ProblemMetaDataDB, RepresentativeNonDominatedSolutions -from desdeo.api.models.representative_solution import RepresentativeSolutionSetRequest +from desdeo.api.models.representative_solution import RepresentativeSolutionSetBase +from desdeo.api.routers.utils import SessionContextGuard from desdeo.problem.testproblems import dtlz2 from .conftest import login @@ -12,7 +15,6 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d """Test that the representative solution set can be added via the endpoint.""" session = session_and_user["session"] user = session_and_user["user"] - access_token = login(client) # Create a test problem @@ -21,8 +23,19 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d session.commit() session.refresh(problem) - solution_set_model = RepresentativeSolutionSetRequest( - problem_id=problem.id, + def test_guard(user, db_session, request=None, problem_id=None): + # problem_id comes from the URL + return SimpleNamespace( + user=user, + db_session=db_session, + problem_db=problem, + interactive_session=None, + parent_state=None, + ) + + client.app.dependency_overrides[SessionContextGuard] = test_guard + + solution_set_model = RepresentativeSolutionSetBase( name="Test solutions", description="Solutions for testing", solution_data={ @@ -38,11 +51,14 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d ) response = client.post( - "/problem/add_representative_solution_set", + f"/problem/{problem.id}/add_representative_solution_set", headers={"Authorization": f"Bearer {access_token}"}, json=solution_set_model.model_dump(), # send as JSON ) + # Clean up override + client.app.dependency_overrides = {} + assert response.status_code == 200 data = response.json() assert data["name"] == solution_set_model.name @@ -95,7 +111,6 @@ def test_get_all_representative_solution_sets(client: TestClient, session_and_us ) # Attach the representative set - solution_set.metadata_id = problem_metadata.id solution_set.metadata_instance = problem_metadata session.add(solution_set) @@ -106,7 +121,7 @@ def test_get_all_representative_solution_sets(client: TestClient, session_and_us # Call GET endpoint response = client.get( - f"/problem/all_representative_solution_sets/{problem.id}", + f"/problem/{problem.id}/all_representative_solution_sets", headers={"Authorization": f"Bearer {access_token}"} ) @@ -137,7 +152,6 @@ def test_get_representative_solution_set(client: TestClient, session_and_user: d # Add a representative solution set solution_set_payload = { - "problem_id": problem.id, "name": "Full Test Solution Set", "description": "Full info for testing", "solution_data": { @@ -153,7 +167,7 @@ def test_get_representative_solution_set(client: TestClient, session_and_user: d } post_response = client.post( - "/problem/add_representative_solution_set", + f"/problem/{problem.id}/add_representative_solution_set", headers={"Authorization": f"Bearer {access_token}"}, json=solution_set_payload, ) diff --git a/desdeo/api/tests/test_routes.py b/desdeo/api/tests/test_routes.py index edc1f3ebb..8c8095227 100644 --- a/desdeo/api/tests/test_routes.py +++ b/desdeo/api/tests/test_routes.py @@ -33,7 +33,6 @@ NIMBUSSaveRequest, NIMBUSSaveResponse, ProblemDB, - ProblemGetRequest, ProblemInfo, ProblemSelectSolverRequest, ReferencePoint, @@ -119,22 +118,26 @@ def test_get_problem(client: TestClient): """Test fetching specific problems based on their id.""" access_token = login(client) - response = post_json(client, "/problem/get", ProblemGetRequest(problem_id=1).model_dump(), access_token) + response = client.get( + "/problem/1", + headers={"Authorization": f"Bearer {access_token}"} + ) assert response.status_code == 200 info = ProblemInfo.model_validate(response.json()) - assert info.id == 1 assert info.name == "dtlz2" assert info.problem_metadata is None - response = post_json(client, "problem/get", ProblemGetRequest(problem_id=2).model_dump(), access_token) + response = client.get( + "/problem/2", + headers={"Authorization": f"Bearer {access_token}"} + ) assert response.status_code == 200 info = ProblemInfo.model_validate(response.json()) - assert info.id == 2 assert info.name == "The river pollution problem" assert isinstance(info.problem_metadata.forest_metadata[0], ForestProblemMetaData)