diff --git a/desdeo/api/routers/emo.py b/desdeo/api/routers/emo.py index c9702cc0b..d8fb866b3 100644 --- a/desdeo/api/routers/emo.py +++ b/desdeo/api/routers/emo.py @@ -32,7 +32,7 @@ from desdeo.problem import Problem from desdeo.tools.score_bands import SCOREBandsConfig, score_json -from .utils import SessionContext, get_session_context +from .utils import ContextField, SessionContext, SessionContextGuard router = APIRouter(prefix="/method/emo", tags=["EMO"]) @@ -146,7 +146,10 @@ def get_templates() -> list[TemplateOptions]: @router.post("/iterate") def iterate( request: EMOIterateRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> EMOIterateResponse: """Fetches results from a completed EMO method. @@ -159,16 +162,12 @@ def iterate( """ # Get context objects db_session = context.db_session - interactive_session = context.interactive_session - parent_state = context.parent_state - - # Ensure problem exists - if context.problem_db is None: - raise HTTPException(status_code=404, detail="Problem not found") - problem_db = context.problem_db problem = Problem.from_problemdb(problem_db) + interactive_session = context.interactive_session + parent_state = context.parent_state + # Templates templates = request.template_options or get_templates() @@ -218,7 +217,6 @@ def iterate( return EMOIterateResponse(method_ids=web_socket_ids, client_id=client_id, state_id=state_id) - def _spawn_emo_process( problem: Problem, templates: list[TemplateOptions], @@ -363,7 +361,10 @@ async def _ea_async( @router.post("/fetch") async def fetch_results( request: EMOFetchRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PARENT_STATE])) + ], ) -> StreamingResponse: """Fetches results from a completed EMO method. @@ -378,9 +379,6 @@ async def fetch_results( # Use context instead of manual fetch state = context.parent_state - if state is None: - raise HTTPException(status_code=404, detail="Parent state not found.") - if not isinstance(state.state, EMOIterateState): raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.") @@ -406,11 +404,13 @@ def result_stream(): return StreamingResponse(result_stream()) - @router.post("/fetch_score") async def fetch_score_bands( request: EMOScoreRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM, ContextField.PARENT_STATE])) + ], ) -> EMOScoreResponse: """Fetches results from a completed EMO method. @@ -429,9 +429,6 @@ async def fetch_score_bands( db_session = context.db_session problem_db = context.problem_db - if parent_state is None: - raise HTTPException(status_code=404, detail="Parent state not found.") - if not isinstance(parent_state.state, EMOIterateState): raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.") diff --git a/desdeo/api/routers/enautilus.py b/desdeo/api/routers/enautilus.py index 2a1a89fc6..fab033cf8 100644 --- a/desdeo/api/routers/enautilus.py +++ b/desdeo/api/routers/enautilus.py @@ -28,10 +28,7 @@ from desdeo.mcdm import ENautilusResult, enautilus_get_representative_solutions, enautilus_step from desdeo.problem import Problem -from .utils import ( - SessionContext, - get_session_context, -) +from .utils import ContextField, SessionContext, SessionContextGuard router = APIRouter(prefix="/method/enautilus") @@ -39,17 +36,14 @@ @router.post("/step") def step( request: ENautilusStepRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))], ) -> ENautilusStepResponse: """Steps the E-NAUTILUS method.""" - # user = context.user # not used here db_session = context.db_session - problem_db = context.problem_db problem = Problem.from_problemdb(problem_db) interactive_session = context.interactive_session - parent_state = context.parent_state representative_solutions = db_session.exec( @@ -247,7 +241,10 @@ def get_representative( @router.post("/finalize") def finalize_enautilus( request: ENautilusFinalizeRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM, ContextField.PARENT_STATE])), + ], ) -> ENautilusFinalizeResponse: """Finalize E-NAUTILUS by selecting the final solution. @@ -365,8 +362,7 @@ def finalize_enautilus( @router.get("/session_tree/{session_id}") def get_session_tree( - session_id: int, - db_session: Annotated[Session, Depends(get_session)], + session_id: int, context: Annotated[SessionContext, Depends(SessionContextGuard())] ) -> ENautilusSessionTreeResponse: """Extract the full E-NAUTILUS decision tree for a session. @@ -375,11 +371,13 @@ def get_session_tree( Args: session_id: The interactive session ID. - db_session: The database session. + context: The context of the query. Returns: ENautilusSessionTreeResponse with nodes, edges, root_ids, and decision_events. """ + db_session = context.db_session + # Query step states step_stmt = ( select(StateDB) diff --git a/desdeo/api/routers/generic.py b/desdeo/api/routers/generic.py index 1ebcc1761..b7a5d01da 100644 --- a/desdeo/api/routers/generic.py +++ b/desdeo/api/routers/generic.py @@ -21,7 +21,7 @@ from desdeo.tools import SolverResults from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions -from .utils import SessionContext, get_session_context +from .utils import ContextField, SessionContext, SessionContextGuard router = APIRouter(prefix="/method/generic") @@ -29,7 +29,10 @@ @router.post("/intermediate") def solve_intermediate( request: IntermediateSolutionRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> GenericIntermediateSolutionResponse: """Solve intermediate solutions between given two solutions. @@ -82,13 +85,6 @@ def solve_intermediate( var_and_obj_values_of_references.append((var_values, obj_values)) - # Problem is now already loaded via context - if problem_db is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with id={request.problem_id} could not be found.", - ) - problem = Problem.from_problemdb(problem_db) solver_results: list[SolverResults] = solve_intermediate_solutions( @@ -145,7 +141,6 @@ def solve_intermediate( ], ) - @router.post("/score-bands-obj-data") def calculate_score_bands_from_objective_data( request: ScoreBandsRequest, diff --git a/desdeo/api/routers/nimbus.py b/desdeo/api/routers/nimbus.py index d8b2a3f69..9373fddf8 100644 --- a/desdeo/api/routers/nimbus.py +++ b/desdeo/api/routers/nimbus.py @@ -39,7 +39,7 @@ from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import SessionContext, get_session_context +from .utils import ContextField, SessionContext, SessionContextGuard router = APIRouter(prefix="/method/nimbus") @@ -108,7 +108,10 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list @router.post("/solve") def solve_solutions( request: NIMBUSClassificationRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> NIMBUSClassificationResponse: """Solve the problem using the NIMBUS method.""" db_session = context.db_session @@ -117,12 +120,6 @@ def solve_solutions( interactive_session = context.interactive_session parent_state = context.parent_state - # Ensure problem exists - if problem_db is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." - ) - solver = check_solver(problem_db=problem_db) problem = Problem.from_problemdb(problem_db) @@ -181,7 +178,10 @@ def solve_solutions( @router.post("/initialize") def initialize( request: NIMBUSInitializationRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> NIMBUSInitializationResponse: """Initialize the problem for the NIMBUS method.""" db_session = context.db_session @@ -190,11 +190,6 @@ def initialize( interactive_session = context.interactive_session parent_state = context.parent_state - if problem_db is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." - ) - solver = check_solver(problem_db=problem_db) problem = Problem.from_problemdb(problem_db) @@ -258,7 +253,8 @@ def initialize( @router.post("/save") def save( - request: NIMBUSSaveRequest, context: Annotated[SessionContext, Depends(get_session_context)] + request: NIMBUSSaveRequest, + context: Annotated[SessionContext, Depends(SessionContextGuard())] ) -> NIMBUSSaveResponse: """Save solutions.""" db_session = context.db_session @@ -332,7 +328,10 @@ def save( @router.post("/intermediate") def solve_nimbus_intermediate( request: IntermediateSolutionRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> NIMBUSIntermediateSolutionResponse: """Solve intermediate solutions by forwarding the request to generic intermediate endpoint with context nimbus.""" db_session = context.db_session @@ -362,7 +361,10 @@ def solve_nimbus_intermediate( @router.post("/get-or-initialize") def get_or_initialize( request: NIMBUSInitializationRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> ( NIMBUSInitializationResponse | NIMBUSClassificationResponse @@ -456,13 +458,14 @@ def get_or_initialize( @router.post("/finalize") def finalize_nimbus( - request: NIMBUSFinalizeRequest, context: Annotated[SessionContext, Depends(get_session_context)] + request: NIMBUSFinalizeRequest, + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))] ) -> NIMBUSFinalizeResponse: """An endpoint for finishing up the nimbus process. Args: request (NIMBUSFinalizeRequest): The request containing the final solution, etc. - context (Annotated[User, get_session_context): The current context. + context (Annotated[SessionContext, SessionContextGuard): The current context. Raises: HTTPException @@ -523,7 +526,7 @@ def finalize_nimbus( @router.post("/delete_save") def delete_save( request: NIMBUSDeleteSaveRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[SessionContext, Depends(SessionContextGuard())] ) -> NIMBUSDeleteSaveResponse: """Endpoint for deleting saved solutions. diff --git a/desdeo/api/routers/problem.py b/desdeo/api/routers/problem.py index 4d572f85b..3b8a297b5 100644 --- a/desdeo/api/routers/problem.py +++ b/desdeo/api/routers/problem.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status from fastapi.responses import JSONResponse -from sqlmodel import Session, select +from sqlmodel import Session from desdeo.api.models import ( ForestProblemMetaData, @@ -30,7 +30,7 @@ from desdeo.problem import Problem from desdeo.tools.utils import available_solvers -from .utils import SessionContext, get_session_context, get_session_context_without_request +from .utils import ContextField, SessionContext, SessionContextGuard router = APIRouter(prefix="/problem") @@ -93,7 +93,7 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ @router.post("/get") def get_problem( request: ProblemGetRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))], ) -> ProblemInfo: """Get the model of a specific problem. @@ -107,22 +107,13 @@ def get_problem( Returns: ProblemInfo: detailed information on the requested problem. """ - problem_db = context.problem_db - - # Ensure problem exists - if problem_db is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"The problem with the requested id={request.problem_id} was not found.", - ) - - return problem_db + return context.problem_db @router.post("/add") def add_problem( request: Annotated[Problem, Depends(parse_problem_json)], - context: Annotated[SessionContext, Depends(get_session_context_without_request)], + context: Annotated[SessionContext, Depends(SessionContextGuard())], ) -> ProblemInfo: """Add a newly defined problem to the database. @@ -166,7 +157,7 @@ def add_problem( @router.post("/add_json") def add_problem_json( json_file: UploadFile, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], + context: Annotated[SessionContext, Depends(SessionContextGuard())], ) -> ProblemInfo: """Adds a problem to the database based on its JSON definition. @@ -207,7 +198,7 @@ def add_problem_json( @router.post("/get_metadata") def get_metadata( request: ProblemMetaDataGetRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[]))], ) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]: """Fetch specific metadata for a specific problem. @@ -224,17 +215,11 @@ def get_metadata( defined for the problem with the requested metadata type. If no match is found, returns an empty list. """ - db_session = context.db_session - - problem_from_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() - - if problem_from_db is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with ID {request.problem_id} not found!", - ) + problem_db = context.db_session.get(ProblemDB, request.problem_id) + if not problem_db: + raise HTTPException(status_code=404, detail=f"Problem with ID {request.problem_id} not found!") - problem_metadata = problem_from_db.problem_metadata + problem_metadata = problem_db.problem_metadata if problem_metadata is None: # no metadata define for the problem @@ -252,7 +237,7 @@ def get_available_solvers() -> list[str]: @router.post("/assign_solver") def select_solver( request: ProblemSelectSolverRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))], ) -> JSONResponse: """Assign a specific solver for a problem. @@ -268,6 +253,7 @@ def select_solver( """ db_session = context.db_session user = context.user + problem_db = context.problem_db # guaranteed # Validate solver type if request.solver_string_representation not in [x for x, _ in available_solvers.items()]: @@ -276,15 +262,6 @@ def select_solver( status_code=status.HTTP_404_NOT_FOUND, ) - # Fetch problem - problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() - - if problem_db is None: - raise HTTPException( - detail=f"No problem with ID {request.problem_id}!", - status_code=status.HTTP_404_NOT_FOUND, - ) - # Auth the user if user.id != problem_db.user_id: raise HTTPException( @@ -293,13 +270,10 @@ def select_solver( ) # All good, get on with it. - problem_metadata = problem_db.problem_metadata - if problem_metadata is None: - # There's no metadata for this problem! Create some. - problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db) - db_session.add(problem_metadata) - db_session.commit() - db_session.refresh(problem_metadata) + problem_metadata = problem_db.problem_metadata or ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db) + db_session.add(problem_metadata) + db_session.commit() + db_session.refresh(problem_metadata) # Remove existing solver selection metadata if problem_metadata.solver_selection_metadata: @@ -328,8 +302,8 @@ def select_solver( @router.post("/add_representative_solution_set") def add_representative_solution_set( request: RepresentativeSolutionSetRequest, - context: Annotated[SessionContext, Depends(get_session_context)], -) -> RepresentativeSolutionSetInfo: + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))], +): """Add a new representative solution set as metadata to a problem. Args: @@ -346,17 +320,10 @@ def add_representative_solution_set( db_session: Session = context.db_session problem_db = context.problem_db - if problem_db is None: - raise HTTPException(status_code=500, detail="Problem context missing.") - - # Ensure metadata object exists - if problem_db.problem_metadata is None: - problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db) - db_session.add(problem_metadata) - db_session.commit() - db_session.refresh(problem_metadata) - else: - problem_metadata = problem_db.problem_metadata + problem_metadata = problem_db.problem_metadata or ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db) + db_session.add(problem_metadata) + db_session.commit() + db_session.refresh(problem_metadata) # Add new representative solution set repr_metadata = RepresentativeNonDominatedSolutions( @@ -386,8 +353,8 @@ def add_representative_solution_set( @router.get("/all_representative_solution_sets/{problem_id}") def get_all_representative_solution_sets( problem_id: int, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], -) -> list[RepresentativeSolutionSetInfo]: + context: Annotated[SessionContext, Depends(SessionContextGuard(require=[]))], +): """Get meta information about all representative solution sets for a given problem. Returns only name, description, ideal, and nadir for each set. @@ -426,9 +393,9 @@ def get_all_representative_solution_sets( @router.get("/representative_solution_set/{set_id}") def get_representative_solution_set( set_id: int, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], -) -> RepresentativeSolutionSetFull: - """Fetch full information of a single representative solution by its ID.""" + context: Annotated[SessionContext, Depends(SessionContextGuard())], +): + """Fetch full information of a single representative solution set by its ID.""" db_session: Session = context.db_session # Fetch the representative set @@ -455,7 +422,7 @@ def get_representative_solution_set( @router.delete("/representative_solution_set/{set_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_representative_solution_set( set_id: int, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], + context: Annotated[SessionContext, Depends(SessionContextGuard())], ): """Delete a representative solution set by its ID.""" db_session: Session = context.db_session @@ -479,7 +446,7 @@ def delete_representative_solution_set( @router.delete("/{problem_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_problem( problem_id: int, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], + context: Annotated[SessionContext, Depends(SessionContextGuard())], ): """Delete a problem by its ID.""" db_session: Session = context.db_session @@ -499,7 +466,7 @@ def delete_problem( @router.get("/{problem_id}/json") def get_problem_json( problem_id: int, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], + context: Annotated[SessionContext, Depends(SessionContextGuard())], ) -> JSONResponse: """Return a Problem as a serialized JSON object suitable for download/re-upload.""" db_session: Session = context.db_session diff --git a/desdeo/api/routers/reference_point_method.py b/desdeo/api/routers/reference_point_method.py index 2182cbe7a..4e3f91667 100644 --- a/desdeo/api/routers/reference_point_method.py +++ b/desdeo/api/routers/reference_point_method.py @@ -15,7 +15,7 @@ from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import SessionContext, get_session_context +from .utils import ContextField, SessionContext, SessionContextGuard router = APIRouter(prefix="/method/rpm") @@ -23,7 +23,10 @@ @router.post("/solve") def solve_solutions( request: RPMSolveRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard(require=[ContextField.PROBLEM])) + ], ) -> RPMState: """Runs an iteration of the reference point method. diff --git a/desdeo/api/routers/session.py b/desdeo/api/routers/session.py index 95e6c90ab..a5addcd81 100644 --- a/desdeo/api/routers/session.py +++ b/desdeo/api/routers/session.py @@ -14,7 +14,7 @@ User, ) from desdeo.api.routers.user_authentication import get_current_user -from desdeo.api.routers.utils import SessionContext, fetch_interactive_session, get_session_context_without_request +from desdeo.api.routers.utils import SessionContext, SessionContextGuard, fetch_interactive_session router = APIRouter(prefix="/session") @@ -22,7 +22,7 @@ @router.post("/new") def create_new_session( request: CreateSessionRequest, - context: Annotated[SessionContext, Depends(get_session_context_without_request)], + context: Annotated[SessionContext, Depends(SessionContextGuard())], ) -> InteractiveSessionInfo: """Creates a new interactive session.""" user = context.user @@ -44,7 +44,6 @@ def create_new_session( return interactive_session - @router.get("/get/{session_id}") def get_session( session_id: int, diff --git a/desdeo/api/routers/utils.py b/desdeo/api/routers/utils.py index 54fafafa8..86b52a65e 100644 --- a/desdeo/api/routers/utils.py +++ b/desdeo/api/routers/utils.py @@ -3,7 +3,9 @@ NOTE: No routers should be defined in this file! """ +from collections.abc import Iterable from dataclasses import dataclass +from enum import StrEnum from typing import Annotated from fastapi import Depends, HTTPException, status @@ -19,9 +21,10 @@ User, ) from desdeo.api.models.representative_solution import RepresentativeSolutionSetRequest +from desdeo.api.models.session import CreateSessionRequest from desdeo.api.routers.user_authentication import get_current_user -RequestType = RPMSolveRequest | ENautilusStepRequest | RepresentativeSolutionSetRequest +RequestType = RPMSolveRequest | ENautilusStepRequest | RepresentativeSolutionSetRequest | CreateSessionRequest def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None: @@ -145,6 +148,14 @@ def fetch_parent_state( return parent_state +class ContextField(StrEnum): + """Enum class to specify context fields.""" + + PROBLEM = "problem_db" + INTERACTIVE_SESSION = "interactive_session" + PARENT_STATE = "parent_state" + + @dataclass(frozen=True) class SessionContext: """A generic context to be used in various endpoints.""" @@ -156,38 +167,72 @@ class SessionContext: parent_state: StateDB | None = None -def get_session_context( - request: RequestType, - user: Annotated[User, Depends(get_current_user)], - db_session: Annotated[Session, Depends(get_session)], -) -> SessionContext: - """Gets the current session context. Should be used as a dep. - - Args: - request (RequestType): request based on which the context is fetched. - user (Annotated[User, Depends): the current user (dep). - db_session (Annotated[Session, Depends): the current database session (dep). +class SessionContextGuard: + """FastAPI dependency that builds a SessionContext and validates required fields.""" + + def __init__(self, require: Iterable[ContextField] | None = None): + """Init method for the SessionContextGuard class. + + Args: + require (Iterable[ContextField] | None, optional): fields that the guard will check + are included in the request. Defaults to None. + """ + self.require = set(require or []) + + def __call__( + self, + user: Annotated[User, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_session)], + request: RequestType | None = None, + ) -> SessionContext: + """Call method for the SessionContextGuard class. + + Args: + user (Annotated[User, Depends): the current user (dep) + db_session (Annotated[Session, Depends): the current database session (dep). + request (RequestType | None, optional): request based on which the context is fetched. + Defaults to None. + + Returns: + SessionContext: the session context with the required fields specified in `self.require`. + """ + problem_db = None + interactive_session = None + parent_state = None + + # Only fetch request-based context if request exists + if request is not None: + if hasattr(request, "problem_id"): + problem_db = fetch_user_problem(user, request, db_session) + + if hasattr(request, "interactive_session_id") or hasattr(request, "problem_id"): + interactive_session = fetch_interactive_session(user, request, db_session) + + if hasattr(request, "parent_state_id") or hasattr(request, "problem_id"): + parent_state = fetch_parent_state( + user, + request, + db_session, + interactive_session=interactive_session, + ) + + context = SessionContext( + user=user, + db_session=db_session, + problem_db=problem_db, + interactive_session=interactive_session, + parent_state=parent_state, + ) - Returns: - SessionContext: the current session context with the relevant instances - of `User`, `Session`, `ProblemDB`, `InteractiveSessionDB`, and `StateDB`. - """ - problem_db = fetch_user_problem(user, request, db_session) - interactive_session = fetch_interactive_session(user, request, db_session) - parent_state = fetch_parent_state(user, request, db_session, interactive_session=interactive_session) - - return SessionContext( - user=user, - db_session=db_session, - problem_db=problem_db, - interactive_session=interactive_session, - parent_state=parent_state, - ) + self._validate(context) + return context -def get_session_context_without_request( - user: Annotated[User, Depends(get_current_user)], - db_session: Annotated[Session, Depends(get_session)], -) -> SessionContext: - """Gets the current session context. Should be used as a dep.""" - return SessionContext(user=user, db_session=db_session) + def _validate(self, context: SessionContext) -> None: + """Ensure required fields exist.""" + for field in self.require: + if getattr(context, field.value) is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"{field} context missing.", + ) diff --git a/desdeo/api/routers/utopia.py b/desdeo/api/routers/utopia.py index bf612a1f2..a8b303bd6 100644 --- a/desdeo/api/routers/utopia.py +++ b/desdeo/api/routers/utopia.py @@ -16,7 +16,7 @@ UtopiaRequest, UtopiaResponse, ) -from desdeo.api.routers.utils import SessionContext, get_session_context +from desdeo.api.routers.utils import SessionContext, SessionContextGuard router = APIRouter(prefix="/utopia") @@ -24,14 +24,17 @@ @router.post("/") def get_utopia_data( # noqa: C901 request: UtopiaRequest, - context: Annotated[SessionContext, Depends(get_session_context)], + context: Annotated[ + SessionContext, + Depends(SessionContextGuard()) + ], ) -> UtopiaResponse: """Request and receive the Utopia map corresponding to the decision variables sent. Args: request (UtopiaRequest): the set of decision variables and problem for which the utopia forest map is requested for. - context (Annotated[SessionContext, Depends(get_session_context)]): the current session context + context (Annotated[SessionContext, Depends(SessionContextGuard)]): the current session context Raises: HTTPException: diff --git a/desdeo/api/tests/test_problem_metadata.py b/desdeo/api/tests/test_problem_metadata.py index 43a9b0748..3d623e250 100644 --- a/desdeo/api/tests/test_problem_metadata.py +++ b/desdeo/api/tests/test_problem_metadata.py @@ -77,24 +77,25 @@ def test_get_all_representative_solution_sets(client: TestClient, session_and_us session.commit() session.refresh(problem) + # Attach problem metadata + problem_metadata = ProblemMetaDataDB(problem_id=problem.id, problem=problem) + session.add(problem_metadata) + session.commit() + session.refresh(problem_metadata) + # Add a representative solution set solution_set = RepresentativeNonDominatedSolutions( metadata_id=None, + metadata_instance=None, name="Test Set GET", description="Description GET", solution_data={"x": [1.0, 2.0], "f": [0.1, 0.2]}, ideal={"f_1": 0.1}, nadir={"f_1": 0.2}, - metadata_instance=None ) - # Attach problem metadata - problem_metadata = ProblemMetaDataDB(problem_id=problem.id, problem=problem) - session.add(problem_metadata) - session.commit() - session.refresh(problem_metadata) - # Attach the representative set + solution_set.metadata_id = problem_metadata.id solution_set.metadata_instance = problem_metadata session.add(solution_set)