diff --git a/desdeo/api/models/nimbus.py b/desdeo/api/models/nimbus.py index 41a46aece..2db73564a 100644 --- a/desdeo/api/models/nimbus.py +++ b/desdeo/api/models/nimbus.py @@ -41,6 +41,7 @@ class NIMBUSDeleteSaveRequest(SQLModel): state_id : int = Field(description="The ID of the save state.") solution_index: int = Field(description="The ID of the solution within the above state.") + problem_id: int = Field(description="The ID of the problem.") class NIMBUSFinalizeRequest(SQLModel): @@ -103,7 +104,7 @@ class NIMBUSDeleteSaveResponse(SQLModel): response_type: str = "nimbus.delete_save" - message: str | None + message: str | None = None class NIMBUSFinalizeResponse(SQLModel): """The response from NIMBUS finish endpoint.""" diff --git a/desdeo/api/routers/emo.py b/desdeo/api/routers/emo.py index 6a941a985..e0f0c11a0 100644 --- a/desdeo/api/routers/emo.py +++ b/desdeo/api/routers/emo.py @@ -38,7 +38,7 @@ from desdeo.problem import Problem from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json -from .utils import fetch_interactive_session, fetch_user_problem +from .utils import fetch_interactive_session, fetch_user_problem, get_session_context, SessionContext router = APIRouter(prefix="/method/emo", tags=["EMO"]) @@ -149,75 +149,55 @@ def get_templates() -> list[TemplateOptions]: templates.append(template.template) return templates - @router.post("/iterate") def iterate( request: EMOIterateRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> EMOIterateResponse: - """Starts the EMO method. - - Args: - request (EMOSolveRequest): The request object containing parameters for the EMO method. - user (Annotated[User, Depends]): The current user. - session (Annotated[Session, Depends]): The database session. - - Raises: - HTTPException: If the request is invalid or the EMO method fails. + """Fetches results from a completed EMO method. - Returns: - IterateResponse: A response object containing a list of IDs to be used for websocket communication. - Also contains the StateDB id where the results will be stored. + Args: request (EMOIterateRequest): The request object containing parameters for fetching results. + context (Annotated[SessionContext, Depends]): The session context. """ - interactive_session: InteractiveSessionDB | None = fetch_interactive_session(user, request, session) + # 1) Get context objects + db_session = context.db_session + interactive_session = context.interactive_session + parent_state = context.parent_state + + # 2) Ensure problem exists + if context.problem_db is None: + raise HTTPException(status_code=404, detail="Problem not found") - problem_db = fetch_user_problem(user, request, session) + problem_db = context.problem_db problem = Problem.from_problemdb(problem_db) - templates = request.template_options + # 3) Templates + templates = request.template_options or get_templates() - if templates is None: - templates = get_templates() + web_socket_ids = [ + f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" + for template in templates + ] - web_socket_ids = [] - for template in templates: - # Ensure unique names - web_socket_ids.append(f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}") client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" - client_id = "client" - - # Save request (incomplete and EAs have not finished running yet) - - # Handle parent state - if request.parent_state_id is None: - parent_state = None - else: - statement = select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() - - if parent_state is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find state with id={request.parent_state_id}", - ) + # 4) Create incomplete state emo_iterate_state = EMOIterateState( template_options=jsonable_encoder(templates), preference_options=jsonable_encoder(request.preference_options), ) incomplete_db_state = StateDB.create( - database_session=session, + database_session=db_session, problem_id=problem_db.id, session_id=interactive_session.id if interactive_session else None, parent_id=parent_state.id if parent_state else None, state=emo_iterate_state, ) - session.add(incomplete_db_state) - session.commit() - session.refresh(incomplete_db_state) + db_session.add(incomplete_db_state) + db_session.commit() + db_session.refresh(incomplete_db_state) state_id = incomplete_db_state.id if state_id is None: @@ -225,10 +205,8 @@ def iterate( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create a new state in the database.", ) - # Close db session - session.close() - # Spawn a new process to handle EMO method creation + # 5) Start process Process( target=_spawn_emo_process, args=( @@ -384,70 +362,66 @@ async def _ea_async( # noqa: PLR0913 await ws.send(f'{{"message": "Finished {websocket_id}", "send_to": "{client_id}"}}') results_dict[websocket_id] = results - @router.post("/fetch") async def fetch_results( request: EMOFetchRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> StreamingResponse: """Fetches results from a completed EMO method. Args: request (EMOFetchRequest): The request object containing parameters for fetching results. - user (Annotated[User, Depends]): The current user. - session (Annotated[Session, Depends]): The database session. + context (Annotated[SessionContext, Depends]): The session context. - Raises: - HTTPException: If the request is invalid or the EMO method has not completed. + Raises: HTTPException: If the request is invalid or the EMO method has not completed. - Returns: - StreamingResponse: A streaming response containing the results of the EMO method. + Returns: StreamingResponse: A streaming response containing the results of the EMO method. """ - parent_state = request.parent_state_id - statement = select(StateDB).where(StateDB.id == parent_state) - state = session.exec(statement).first() + # Use context instead of manual fetch + state = context.parent_state + if state is None: raise HTTPException(status_code=404, detail="Parent state not found.") if not isinstance(state.state, EMOIterateState): - raise TypeError(f"State with id={parent_state} is not of type EMOIterateState.") + raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.") if not (state.state.objective_values and state.state.decision_variables): - raise ValueError(f"State does not contain results yet.") + raise ValueError("State does not contain results yet.") # Convert objs: dict[str, list[float]] to objs: list[dict[str, float]] raw_objs: dict[str, list[float]] = state.state.objective_values n_solutions = len(next(iter(raw_objs.values()))) - objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)] + objs: list[dict[str, float]] = [ + {k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions) + ] raw_decs: dict[str, list[float]] = state.state.decision_variables - - decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)] - - response: list[Solution] = [] + decs: list[dict[str, float]] = [ + {k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions) + ] def result_stream(): for i in range(n_solutions): - item = {"solution_id": i, "objective_values": objs[i], "decision_variables": decs[i]} + item = { + "solution_id": i, + "objective_values": objs[i], + "decision_variables": decs[i], + } yield json.dumps(item) + "\n" return StreamingResponse(result_stream()) - @router.post("/fetch_score") async def fetch_score_bands( request: EMOScoreRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> EMOScoreResponse: """Fetches results from a completed EMO method. - Args: - request (EMOFetchRequest): The request object containing parameters for fetching results and of the SCORE bands - visualization. - user (Annotated[User, Depends]): The current user. - session (Annotated[Session, Depends]): The database session. + Args: request (EMOFetchRequest): The request object containing parameters for fetching + results and of the SCORE bands visualization. + context (Annotated[SessionContext, Depends]): The session context. Raises: HTTPException: If the request is invalid or the EMO method has not completed. @@ -455,23 +429,25 @@ async def fetch_score_bands( Returns: SCOREBandsResult: The results of the SCORE bands visualization. """ - if request.config is None: - score_config = SCOREBandsConfig() - else: - score_config = request.config - parent_state = request.parent_state_id - statement = select(StateDB).where(StateDB.id == parent_state) - state = session.exec(statement).first() + # Use context instead of manual fetch + state = context.parent_state + db_session = context.db_session + problem_db = context.problem_db + if state is None: raise HTTPException(status_code=404, detail="Parent state not found.") if not isinstance(state.state, EMOIterateState): - raise TypeError(f"State with id={parent_state} is not of type EMOIterateState.") + raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.") if not (state.state.objective_values and state.state.decision_variables): - raise ValueError(f"State does not contain results yet.") + raise ValueError("State does not contain results yet.") + + if request.config is None: + score_config = SCOREBandsConfig() + else: + score_config = request.config - # Convert objs: dict[str, list[float]] to objs: list[dict[str, float]] raw_objs: dict[str, list[float]] = state.state.objective_values objs = pl.DataFrame(raw_objs) @@ -482,16 +458,19 @@ async def fetch_score_bands( score_state = EMOSCOREState(result=results.model_dump()) + # Use the session + problem from context instead of request directly score_db_state = StateDB.create( - database_session=session, - problem_id=request.problem_id, - session_id=request.session_id, - parent_id=parent_state, + database_session=db_session, + problem_id=problem_db.id, + session_id=state.session_id, + parent_id=state.id, state=score_state, ) - session.add(score_db_state) - session.commit() - session.refresh(score_db_state) + + db_session.add(score_db_state) + db_session.commit() + db_session.refresh(score_db_state) + state_id = score_db_state.id return EMOScoreResponse(result=results, state_id=state_id) diff --git a/desdeo/api/routers/generic.py b/desdeo/api/routers/generic.py index bc0208889..bed234def 100644 --- a/desdeo/api/routers/generic.py +++ b/desdeo/api/routers/generic.py @@ -26,44 +26,52 @@ from desdeo.tools import SolverResults from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions +from .utils import get_session_context, SessionContext router = APIRouter(prefix="/method/generic") - @router.post("/intermediate") def solve_intermediate( request: IntermediateSolutionRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> GenericIntermediateSolutionResponse: - """Solve intermediate solutions between given two solutions.""" - if request.session_id is not None: - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) - interactive_session = session.exec(statement) - - if interactive_session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", - ) - else: - # request.session_id is None: - # use active session instead - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - - interactive_session = session.exec(statement).first() + """Solve intermediate solutions between given two solutions. + + Args: + request (IntermediateSolutionRequest): The request object containing parameters + for fetching results. + context (Annotated[SessionContext, Depends]): The session context. + """ + db_session = context.db_session + user = context.user + problem_db = context.problem_db + interactive_session = context.interactive_session + parent_state = context.parent_state + + # -------------------------------------- + # Validate interactive session + # -------------------------------------- + if interactive_session is None and request.session_id is not None: + # session id was explicitly requested but not found + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Could not find interactive session with id={request.session_id}.", + ) + # -------------------------------------- # query both reference solutions' variable values - # stored as lit of tuples, first element of each tuple are variables values, second are objective function values + # -------------------------------------- var_and_obj_values_of_references: list[tuple[dict, dict]] = [] reference_states = [] + for solution_info in [request.reference_solution_1, request.reference_solution_2]: - solution_state = session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first() + solution_state = db_session.exec( + select(StateDB).where(StateDB.id == solution_info.state_id) + ).first() if solution_state is None: - # no StateDB found with the given id raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find a state with the given id{solution_state.state_id}.", + detail=f"Could not find a state with id={solution_info.state_id}.", ) reference_states.append(solution_state) @@ -71,7 +79,6 @@ def solve_intermediate( try: _var_values = solution_state.state.result_variable_values var_values = _var_values[solution_info.solution_index] - except IndexError as exc: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -83,7 +90,6 @@ def solve_intermediate( try: _obj_values = solution_state.state.result_objective_values obj_values = _obj_values[solution_info.solution_index] - except IndexError as exc: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -94,10 +100,9 @@ def solve_intermediate( var_and_obj_values_of_references.append((var_values, obj_values)) - # fetch the problem from the DB - statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id) - problem_db = session.exec(statement).first() - + # -------------------------------------- + # Problem is now already loaded via context + # -------------------------------------- if problem_db is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -116,26 +121,9 @@ def solve_intermediate( solver_options=request.solver_options, ) - # fetch parent state - if request.parent_state_id is None: - # parent state is assumed to be the last state added to the session. - parent_state = ( - interactive_session.states[-1] - if (interactive_session is not None and len(interactive_session.states) > 0) - else None - ) - - else: - # request.parent_state_id is not None - statement = session.select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() - - if parent_state is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find state with id={request.parent_state_id}", - ) - + # -------------------------------------- + # parent_state is already loaded in context + # -------------------------------------- intermediate_state = IntermediateSolutionState( scalarization_options=request.scalarization_options, context=request.context, @@ -149,16 +137,16 @@ def solve_intermediate( # create DB state and add it to the DB state = StateDB.create( - database_session=session, + database_session=db_session, problem_id=problem_db.id, session_id=interactive_session.id if interactive_session is not None else None, parent_id=parent_state.id if parent_state is not None else None, state=intermediate_state, ) - session.add(state) - session.commit() - session.refresh(state) + db_session.add(state) + db_session.commit() + db_session.refresh(state) return GenericIntermediateSolutionResponse( state_id=state.id, @@ -177,7 +165,6 @@ def solve_intermediate( ], ) - @router.post("/score-bands-obj-data") def calculate_score_bands_from_objective_data( request: ScoreBandsRequest, diff --git a/desdeo/api/routers/nimbus.py b/desdeo/api/routers/nimbus.py index b84ceab65..63921a639 100644 --- a/desdeo/api/routers/nimbus.py +++ b/desdeo/api/routers/nimbus.py @@ -43,9 +43,9 @@ from desdeo.problem import Problem from desdeo.tools import SolverResults +from .utils import get_session_context, SessionContext router = APIRouter(prefix="/method/nimbus") - # helper for collecting solutions def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]: """Filters out the duplicate values of objectives.""" @@ -106,62 +106,30 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list return filter_duplicates(all_solutions) - @router.post("/solve") def solve_solutions( request: NIMBUSClassificationRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> NIMBUSClassificationResponse: """Solve the problem using the NIMBUS method.""" - if request.session_id is not None: - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) - interactive_session = session.exec(statement) - - if interactive_session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", - ) - else: - # request.session_id is None: - # use active session instead - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - - interactive_session = session.exec(statement).first() - - # fetch the problem from the DB - statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id) - problem_db = session.exec(statement).first() - + db_session = context.db_session + user = context.user + problem_db = context.problem_db + interactive_session = context.interactive_session + parent_state = context.parent_state + + # ----------------------------- + # Ensure problem exists + # ----------------------------- if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Problem with id={request.problem_id} could not be found." ) solver = check_solver(problem_db=problem_db) - problem = Problem.from_problemdb(problem_db) - # fetch parent state - if request.parent_state_id is None: - # parent state is assumed to be the last state added to the session. - parent_state = ( - interactive_session.states[-1] - if (interactive_session is not None and len(interactive_session.states) > 0) - else None - ) - - else: - # request.parent_state_id is not None - statement = select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() - - if parent_state is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" - ) - solver_results: list[SolverResults] = solve_sub_problems( problem=problem, current_objectives=request.current_objectives, @@ -185,24 +153,24 @@ def solve_solutions( # create DB state and add it to the DB state = StateDB.create( - database_session=session, + database_session=db_session, problem_id=problem_db.id, session_id=interactive_session.id if interactive_session is not None else None, parent_id=parent_state.id if parent_state is not None else None, state=nimbus_state, ) - session.add(state) - session.commit() - session.refresh(state) + db_session.add(state) + db_session.commit() + db_session.refresh(state) # Collect all current solutions current_solutions: list[SolutionReference] = [] for i, _ in enumerate(solver_results): current_solutions.append(SolutionReference(state=state, solution_index=i)) - saved_solutions = collect_saved_solutions(user, request.problem_id, session) - all_solutions = collect_all_solutions(user, request.problem_id, session) + saved_solutions = collect_saved_solutions(user, request.problem_id, db_session) + all_solutions = collect_all_solutions(user, request.problem_id, db_session) return NIMBUSClassificationResponse( state_id=state.id, @@ -217,60 +185,41 @@ def solve_solutions( @router.post("/initialize") def initialize( request: NIMBUSInitializationRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> NIMBUSInitializationResponse: """Initialize the problem for the NIMBUS method.""" - if request.session_id is not None: - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) - interactive_session = session.exec(statement) - - if interactive_session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", - ) - else: - # request.session_id is None: - # use active session instead - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - - interactive_session = session.exec(statement).first() - - print(interactive_session) - - # fetch the problem from the DB - statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id) - problem_db = session.exec(statement).first() + db_session = context.db_session + user = context.user + problem_db = context.problem_db + interactive_session = context.interactive_session + parent_state = context.parent_state if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Problem with id={request.problem_id} could not be found." ) solver = check_solver(problem_db=problem_db) - problem = Problem.from_problemdb(problem_db) if isinstance(ref_point := request.starting_point, ReferencePoint): - # ReferencePoint starting_point = ref_point.aspiration_levels elif isinstance(info := request.starting_point, SolutionInfo): - # SolutionInfo # fetch the solution statement = select(StateDB).where(StateDB.id == info.state_id) - state = session.exec(statement).first() + state = db_session.exec(statement).first() if state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"StateDB with index {info.state_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, + detail=f"StateDB with index {info.state_id} could not be found." ) starting_point = state.state.result_objective_values[info.solution_index] else: - # if not starting point is provided, generate it starting_point = None start_result = generate_starting_point( @@ -281,18 +230,6 @@ def initialize( solver_options=request.solver_options, ) - # fetch parent state if it is given - if request.parent_state_id is None: - parent_state = None - else: - statement = session.select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() - - if parent_state is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" - ) - initialization_state = NIMBUSInitializationState( reference_point=starting_point, scalarization_options=request.scalarization_options, @@ -302,20 +239,20 @@ def initialize( # create DB state and add it to the DB state = StateDB.create( - database_session=session, + database_session=db_session, problem_id=problem_db.id, - session_id=interactive_session.id if interactive_session is not None else None, - parent_id=parent_state.id if parent_state is not None else None, + session_id=interactive_session.id if interactive_session else None, + parent_id=parent_state.id if parent_state else None, state=initialization_state, ) - session.add(state) - session.commit() - session.refresh(state) + db_session.add(state) + db_session.commit() + db_session.refresh(state) current_solutions = [SolutionReference(state=state, solution_index=0)] - saved_solutions = collect_saved_solutions(user, request.problem_id, session) - all_solutions = collect_all_solutions(user, request.problem_id, session) + saved_solutions = collect_saved_solutions(user, request.problem_id, db_session) + all_solutions = collect_all_solutions(user, request.problem_id, db_session) return NIMBUSInitializationResponse( state_id=state.id, @@ -324,47 +261,31 @@ def initialize( all_solutions=all_solutions, ) - @router.post("/save") def save( request: NIMBUSSaveRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)] ) -> NIMBUSSaveResponse: """Save solutions.""" - if request.session_id is not None: - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) - interactive_session = session.exec(statement) + db_session = context.db_session + user = context.user + interactive_session = context.interactive_session + parent_state = context.parent_state - if interactive_session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", - ) - else: - # request.session_id is None: - # use active session instead - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - - interactive_session = session.exec(statement).first() - - # fetch parent state + # fetch parent state (same logic, but using context) if request.parent_state_id is None: - # parent state is assumed to be the last state added to the session. parent_state = ( interactive_session.states[-1] if (interactive_session is not None and len(interactive_session.states) > 0) else None ) - else: - # request.parent_state_id is not None - statement = select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() + parent_state = db_session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first() if parent_state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Could not find state with id={request.parent_state_id}" ) # Check for duplicate solutions and update names instead of saving duplicates @@ -372,7 +293,7 @@ def save( new_solutions: list[UserSavedSolutionDB] = [] for info in request.solution_info: - existing_solution = session.exec( + existing_solution = db_session.exec( select(UserSavedSolutionDB).where( UserSavedSolutionDB.origin_state_id == info.state_id, UserSavedSolutionDB.solution_index == info.solution_index, @@ -380,63 +301,58 @@ def save( ).first() if existing_solution is not None: - # Update the name of the existing solution existing_solution.name = info.name - - session.add(existing_solution) - + db_session.add(existing_solution) updated_solutions.append(existing_solution) + else: - # This is a new solution new_solution = UserSavedSolutionDB.from_state_info( - session, user.id, request.problem_id, info.state_id, info.solution_index, info.name + db_session, user.id, request.problem_id, info.state_id, info.solution_index, info.name ) - session.add(new_solution) - + db_session.add(new_solution) new_solutions.append(new_solution) # Commit existing and new solutions - if updated_solutions or new_solution: - session.commit() - [session.refresh(row) for row in updated_solutions + new_solutions] + if updated_solutions or new_solutions: + db_session.commit() + [db_session.refresh(row) for row in updated_solutions + new_solutions] - # save solver results for state in SolverResults format just for consistency (dont save name field to state) + # save solver results for state in SolverResults format just for consistency save_state = NIMBUSSaveState(solutions=updated_solutions + new_solutions) # create DB state state = StateDB.create( - database_session=session, + database_session=db_session, problem_id=request.problem_id, session_id=interactive_session.id if interactive_session is not None else None, parent_id=parent_state.id if parent_state is not None else None, state=save_state, ) - session.add(state) - session.commit() - session.refresh(state) + db_session.add(state) + db_session.commit() + db_session.refresh(state) return NIMBUSSaveResponse(state_id=state.id) - @router.post("/intermediate") def solve_nimbus_intermediate( request: IntermediateSolutionRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> NIMBUSIntermediateSolutionResponse: """Solve intermediate solutions by forwarding the request to generic intermediate endpoint with context nimbus.""" + db_session = context.db_session + user = context.user + # Add NIMBUS context to request request.context = "nimbus" - # Forward to generic endpoint - intermediate_response = solve_intermediate(request, user, session) - # Get saved solutions for this user and problem - saved_solutions = collect_saved_solutions(user, request.problem_id, session) + # Forward to generic endpoint + intermediate_response = solve_intermediate(request, context) - # Get all solutions including the newly generated intermediate ones - all_solutions = collect_all_solutions(user, request.problem_id, session) + saved_solutions = collect_saved_solutions(user, request.problem_id, db_session) + all_solutions = collect_all_solutions(user, request.problem_id, db_session) return NIMBUSIntermediateSolutionResponse( state_id=intermediate_response.state_id, @@ -447,28 +363,16 @@ def solve_nimbus_intermediate( all_solutions=all_solutions, ) - @router.post("/get-or-initialize") def get_or_initialize( request: NIMBUSInitializationRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> NIMBUSInitializationResponse | NIMBUSClassificationResponse | \ NIMBUSIntermediateSolutionResponse | NIMBUSFinalizeResponse: - """Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't.""" - if request.session_id is not None: - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) - interactive_session = session.exec(statement) - if interactive_session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", - ) - else: - # use active session instead - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - interactive_session = session.exec(statement).first() + db_session = context.db_session + user = context.user + interactive_session = context.interactive_session # Look for latest relevant state in the session statement = ( @@ -479,9 +383,8 @@ def get_or_initialize( ) .order_by(StateDB.id.desc()) ) - states = session.exec(statement).all() + states = db_session.exec(statement).all() - # Find the latest relevant state (NIMBUS classification, initialization, or intermediate with NIMBUS context) latest_state = None for state in states: if isinstance(state.state, (NIMBUSClassificationState | NIMBUSInitializationState | NIMBUSFinalState)) or ( @@ -491,17 +394,15 @@ def get_or_initialize( break if latest_state is not None: - saved_solutions = collect_saved_solutions(user, request.problem_id, session) - all_solutions = collect_all_solutions(user, request.problem_id, session) - # Handle both single result and list of results cases + saved_solutions = collect_saved_solutions(user, request.problem_id, db_session) + all_solutions = collect_all_solutions(user, request.problem_id, db_session) + solver_results = latest_state.state.solver_results - if isinstance(solver_results, list): - current_solutions = [ - SolutionReference(state=latest_state, solution_index=i) for i in range(len(solver_results)) - ] - else: - # Single result case (NIMBUSInitializationState) - current_solutions = [SolutionReference(state=latest_state, solution_index=0)] + current_solutions = ( + [SolutionReference(state=latest_state, solution_index=i) for i in range(len(solver_results))] + if isinstance(solver_results, list) + else [SolutionReference(state=latest_state, solution_index=0)] + ) if isinstance(latest_state.state, NIMBUSClassificationState): return NIMBUSClassificationResponse( @@ -524,7 +425,6 @@ def get_or_initialize( ) if isinstance(latest_state.state, NIMBUSFinalState): - solution_index = latest_state.state.solution_result_index origin_state_id = latest_state.state.solution_origin_state_id @@ -542,7 +442,6 @@ def get_or_initialize( all_solutions=all_solutions, ) - # NIMBUSInitializationState return NIMBUSInitializationResponse( state_id=latest_state.id, current_solutions=current_solutions, @@ -551,69 +450,42 @@ def get_or_initialize( ) # No relevant state found, initialize a new one - return initialize(request, user, session) - + return initialize(request, context) @router.post("/finalize") def finalize_nimbus( request: NIMBUSFinalizeRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)] + context: Annotated[SessionContext, Depends(get_session_context)] ) -> NIMBUSFinalizeResponse: - """An endpoint for finishing up the nimbus process. - - Args: - request (NIMBUSFinalizeRequest): The request containing the final solution, etc. - user (Annotated[User, Depends): The current user. - session (Annotated[Session, Depends): The database session. - Raises: - HTTPException - - Returns: - NIMBUSFinalizeResponse: Response containing info on the final solution. - """ - if request.session_id is not None: - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id) - interactive_session = session.exec(statement) - - if interactive_session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find interactive session with id={request.session_id}.", - ) - else: - # request.session_id is None: - # use active session instead - statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id) - - interactive_session = session.exec(statement).first() + db_session = context.db_session + user = context.user + interactive_session = context.interactive_session + parent_state = context.parent_state if request.parent_state_id is None: parent_state = None else: - statement = session.select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() - + parent_state = db_session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first() if parent_state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Could not find state with id={request.parent_state_id}" ) - # fetch the problem from the DB - statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id) - problem_db = session.exec(statement).first() - + # fetch problem + problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)).first() if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Problem with id={request.problem_id} could not be found." ) solution_state_id = request.solution_info.state_id solution_index = request.solution_info.solution_index - statement = select(StateDB).where(StateDB.id == solution_state_id) - actual_state = session.exec(statement).first().state + state = db_session.exec(select(StateDB).where(StateDB.id == solution_state_id)).first() + actual_state = state.state if state else None if actual_state is None: raise HTTPException( detail="No concrete substate!", @@ -627,18 +499,18 @@ def finalize_nimbus( ) state = StateDB.create( - database_session=session, + database_session=db_session, problem_id=problem_db.id, session_id=interactive_session.id if interactive_session is not None else None, parent_id=parent_state.id if parent_state is not None else None, state=final_state, ) - session.add(state) - session.commit() - session.refresh(state) + db_session.add(state) + db_session.commit() + db_session.refresh(state) - solution_reference_response=SolutionReferenceResponse( + solution_reference_response = SolutionReferenceResponse( solution_index=solution_index, state_id=solution_state_id, objective_values=final_state.solver_results.optimal_objectives, @@ -648,22 +520,20 @@ def finalize_nimbus( return NIMBUSFinalizeResponse( state_id=state.id, final_solution=solution_reference_response, - saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=session), - all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=session), + saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=db_session), + all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=db_session), ) @router.post("/delete_save") def delete_save( request: NIMBUSDeleteSaveRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)] + context: Annotated[SessionContext, Depends(get_session_context)], ) -> NIMBUSDeleteSaveResponse: """Endpoint for deleting saved solutions. Args: request (NIMBUSDeleteSaveRequest): request containing necessary information for deleting a save - user (Annotated[User, Depends): the current (logged in) user - session (Annotated[Session, Depends): database session + context (Annotated[SessionContext, Depends): session context Raises: HTTPException @@ -671,7 +541,9 @@ def delete_save( Returns: NIMBUSDeleteSaveResponse: Response acknowledging the deletion of save and other useful info. """ - to_be_deleted = session.exec( + db_session = context.db_session + + to_be_deleted = db_session.exec( select(UserSavedSolutionDB).where( UserSavedSolutionDB.origin_state_id == request.state_id, UserSavedSolutionDB.solution_index == request.solution_index, @@ -684,10 +556,10 @@ def delete_save( status_code=status.HTTP_404_NOT_FOUND ) - session.delete(to_be_deleted) - session.commit() + db_session.delete(to_be_deleted) + db_session.commit() - to_be_deleted = session.exec( + to_be_deleted = db_session.exec( select(UserSavedSolutionDB).where( UserSavedSolutionDB.origin_state_id == request.state_id, UserSavedSolutionDB.solution_index == request.solution_index, @@ -702,4 +574,4 @@ def delete_save( return NIMBUSDeleteSaveResponse( message="Save deleted." - ) + ) \ No newline at end of file diff --git a/desdeo/api/routers/problem.py b/desdeo/api/routers/problem.py index 8e8a93dd8..1bf2a145f 100644 --- a/desdeo/api/routers/problem.py +++ b/desdeo/api/routers/problem.py @@ -25,6 +25,7 @@ from desdeo.api.routers.user_authentication import get_current_user from desdeo.problem import Problem from desdeo.tools.utils import available_solvers +from .utils import get_session_context, get_session_context_base, SessionContext router = APIRouter(prefix="/problem") @@ -83,63 +84,48 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ """ return user.problems - @router.post("/get") def get_problem( request: ProblemGetRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> ProblemInfo: """Get the model of a specific problem. - Args: - request (ProblemGetRequest): the request containing the problem's id `problem_id`. + Args: request (ProblemGetRequest): the request containing the problem's id `problem_id`. user (Annotated[User, Depends): the current user. session (Annotated[Session, Depends): the database session. - - Raises: - HTTPException: could not find a problem with the given id. - - Returns: - ProblemInfo: detailed information on the requested problem. + Raises: HTTPException: could not find a problem with the given id. + Returns: ProblemInfo: detailed information on the requested problem. """ - problem = session.get(ProblemDB, request.problem_id) + db_session = context.db_session + problem_db = context.problem_db - if problem is None: + # ----------------------------- + # Ensure problem exists + # ----------------------------- + if problem_db is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"The problem with the requested id={request.problem_id} was not found.", ) - return problem - + return problem_db @router.post("/add") def add_problem( request: Annotated[Problem, Depends(parse_problem_json)], - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> ProblemInfo: - """Add a newly defined problem to the database. - - Args: - request (Problem): the JSON representation of the problem. - user (Annotated[User, Depends): the current user. - session (Annotated[Session, Depends): the database session. - - Note: - Users with the role 'guest' may not add new problems. + """Add a newly defined problem to the database.""" + user = context.user + db_session = context.db_session - Raises: - HTTPException: when any issue with defining the problem arises. - - Returns: - ProblemInfo: the information about the problem added. - """ if user.role == UserRole.guest: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Guest users are not allowed to add new problems." + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Guest users are not allowed to add new problems.", ) + try: problem_db = ProblemDB.from_problem(request, user=user) except Exception as e: @@ -148,92 +134,68 @@ def add_problem( detail=f"Could not add problem. Possible reason: {e!r}", ) from e - session.add(problem_db) - session.commit() - session.refresh(problem_db) + db_session.add(problem_db) + db_session.commit() + db_session.refresh(problem_db) return problem_db - @router.post("/add_json") def add_problem_json( json_file: UploadFile, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> ProblemInfo: - """Adds a problem to the database based on its JSON definition. + """Adds a problem to the database based on its JSON definition.""" + user = context.user + db_session = context.db_session - Args: - json_file (UploadFile): a file in JSON format describing the problem. - user (Annotated[User, Depends): the usr for which the problem is added. - session (Annotated[Session, Depends): the database session. - - Raises: - HTTPException: if the provided `json_file` is empty. - HTTPException: if the content in the provided `json_file` is not in JSON format.__annotations__ - - Returns: - ProblemInfo: a description of the added problem. - """ raw = json_file.file.read() if not raw: - raise HTTPException(400, "Empty upload.") + raise HTTPException(status_code=400, detail="Empty upload.") try: - # for extra validation - json.loads(raw) + json.loads(raw) # extra validation except json.JSONDecodeError as e: - raise HTTPException(400, "Invalid JSON.") from e + raise HTTPException(status_code=400, detail="Invalid JSON.") from e problem = Problem.model_validate_json(raw, by_name=True) problem_db = ProblemDB.from_problem(problem, user=user) - session.add(problem_db) - session.commit() - session.refresh(problem_db) + db_session.add(problem_db) + db_session.commit() + db_session.refresh(problem_db) return problem_db - @router.post("/get_metadata") def get_metadata( request: ProblemMetaDataGetRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]: - """Fetch specific metadata for a specific problem. + """Fetch specific metadata for a specific problem.""" + db_session = context.db_session - Fetch specific metadata for a specific problem. See all the possible - metadata types from DESDEO/desdeo/api/models/problem.py Problem Metadata - section. + problem_from_db = db_session.exec( + select(ProblemDB).where(ProblemDB.id == request.problem_id) + ).first() - Args: - request (MetaDataGetRequest): the requested metadata type. - user (Annotated[User, Depends]): the current user. - session (Annotated[Session, Depends]): the database session. - - Returns: - list[ForestProblemMetadata | RepresentativeNonDominatedSolutions]: list containing all the metadata - defined for the problem with the requested metadata type. If no match is found, - returns an empty list. - """ - statement = select(ProblemDB).where(ProblemDB.id == request.problem_id) - problem_from_db = session.exec(statement).first() if problem_from_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with ID {request.problem_id} not found!" + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Problem with ID {request.problem_id} not found!", ) problem_metadata = problem_from_db.problem_metadata if problem_metadata is None: - # no metadata define for the problem return [] - # metadata is defined, try to find matching types based on request - return [metadata for metadata in problem_metadata.all_metadata if metadata.metadata_type == request.metadata_type] - + return [ + metadata + for metadata in problem_metadata.all_metadata + if metadata.metadata_type == request.metadata_type + ] @router.get("/assign/solver", response_model=list[str]) def get_available_solvers() -> list[str]: @@ -243,65 +205,64 @@ def get_available_solvers() -> list[str]: @router.post("/assign_solver") def select_solver( request: ProblemSelectSolverRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> JSONResponse: - """Assign a specific solver for a problem. - - request: ProblemSelectSolverRequest: The request containing problem id and string representation of the solver - user: Annotated[User, Depends(get_current_user): The user that is logged in. - session: Annotated[Session, Depends(get_session)]: The database session. - - Raises: - HTTPException: Unknown solver, unauthorized user + """Assign a specific solver for a problem.""" + db_session = context.db_session + user = context.user - Returns: - JSONResponse: A simple confirmation. - """ + # Validate solver type if request.solver_string_representation not in [x for x, _ in available_solvers.items()]: raise HTTPException( detail=f"Solver of unknown type: {request.solver_string_representation}", status_code=status.HTTP_404_NOT_FOUND, ) - """Set a specific solver for a specific problem.""" - # Get the problem - problem_db = session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() + # Fetch problem + problem_db = db_session.exec( + select(ProblemDB).where(ProblemDB.id == request.problem_id) + ).first() + if problem_db is None: raise HTTPException( detail=f"No problem with ID {request.problem_id}!", status_code=status.HTTP_404_NOT_FOUND, ) - # Auth the user + + # Authorization if user.id != problem_db.user_id: - raise HTTPException(detail="Unauthorized user!", status_code=status.HTTP_401_UNAUTHORIZED) + raise HTTPException( + detail="Unauthorized user!", + status_code=status.HTTP_401_UNAUTHORIZED, + ) - # All good, get on with it. + # Ensure metadata exists problem_metadata = problem_db.problem_metadata if problem_metadata is None: - # There's no metadata for this problem! Create some. problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db) - session.add(problem_metadata) - session.commit() - session.refresh(problem_metadata) + db_session.add(problem_metadata) + db_session.commit() + db_session.refresh(problem_metadata) + # Remove existing solver selection metadata if problem_metadata.solver_selection_metadata: - session.delete(problem_metadata.solver_selection_metadata[-1]) - session.commit() + db_session.delete(problem_metadata.solver_selection_metadata[-1]) + db_session.commit() + # Add new solver selection metadata solver_selection_metadata = SolverSelectionMetadata( metadata_id=problem_metadata.id, solver_string_representation=request.solver_string_representation, metadata_instance=problem_metadata, ) - session.add(solver_selection_metadata) - session.commit() - session.refresh(solver_selection_metadata) + db_session.add(solver_selection_metadata) + db_session.commit() + db_session.refresh(solver_selection_metadata) problem_metadata.solver_selection_metadata.append(solver_selection_metadata) - session.add(problem_metadata) - session.commit() - session.refresh(problem_metadata) + db_session.add(problem_metadata) + db_session.commit() + db_session.refresh(problem_metadata) return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK) diff --git a/desdeo/api/routers/reference_point_method.py b/desdeo/api/routers/reference_point_method.py index 11dd452ec..68d2d49b1 100644 --- a/desdeo/api/routers/reference_point_method.py +++ b/desdeo/api/routers/reference_point_method.py @@ -2,7 +2,7 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException, status from sqlmodel import Session from desdeo.api.db import get_session @@ -21,39 +21,32 @@ from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem +from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem, get_session_context, SessionContext router = APIRouter(prefix="/method/rpm") - @router.post("/solve") def solve_solutions( request: RPMSolveRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context)], ) -> RPMState: - """Runs an iteration of the reference point method. - - Args: - request (RPMSolveRequest): a request with the needed information to run the method. - user (Annotated[User, Depends): the current user. - session (Annotated[Session, Depends): the current database session. - - Returns: - RPMState: a state with information on the results of iterating the reference point method - once. - """ - # fetch interactive session, parent state, and ProblemDB - interactive_session: InteractiveSessionDB = fetch_interactive_session(user, request, session) - parent_state = fetch_parent_state(user, request, session, interactive_session) - - problem_db: ProblemDB = fetch_user_problem(user, request, session) + """Runs an iteration of the reference point method.""" + user = context.user + db_session = context.db_session + problem_db = context.problem_db + interactive_session = context.interactive_session + parent_state = context.parent_state + + # sanity check (defensive, but explicit) + if problem_db is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Problem context missing.", + ) solver = check_solver(problem_db=problem_db) - problem = Problem.from_problemdb(problem_db) - # optimize for solutions solver_results: list[SolverResults] = rpm_solve_solutions( problem, request.preference.aspiration_levels, @@ -62,14 +55,17 @@ def solve_solutions( request.solver_options, ) - # create DB preference - preference_db = PreferenceDB(user_id=user.id, problem_id=problem_db.id, preference=request.preference) + preference_db = PreferenceDB( + user_id=user.id, + problem_id=problem_db.id, + preference=request.preference, + ) - session.add(preference_db) - session.commit() - session.refresh(preference_db) + db_session.add(preference_db) + db_session.commit() + db_session.refresh(preference_db) - # create state and add to DB + # create RPM state (API model) rpm_state = RPMState( scalarization_options=request.scalarization_options, solver=request.solver, @@ -77,17 +73,17 @@ def solve_solutions( solver_results=solver_results, ) - # create DB state and add it to the DB + # create DB state state = StateDB( problem_id=problem_db.id, preference_id=preference_db.id, - session_id=interactive_session.id if interactive_session is not None else None, - parent_id=parent_state.id if parent_state is not None else None, + session_id=interactive_session.id if interactive_session else None, + parent_id=parent_state.id if parent_state else None, state=rpm_state, ) - session.add(state) - session.commit() - session.refresh(state) + db_session.add(state) + db_session.commit() + db_session.refresh(state) - return rpm_state + return rpm_state \ No newline at end of file diff --git a/desdeo/api/routers/session.py b/desdeo/api/routers/session.py index 166ad49ac..2421c0444 100644 --- a/desdeo/api/routers/session.py +++ b/desdeo/api/routers/session.py @@ -14,33 +14,34 @@ User, ) from desdeo.api.routers.user_authentication import get_current_user -from desdeo.api.routers.utils import fetch_interactive_session +from desdeo.api.routers.utils import fetch_interactive_session, get_session_context_base, SessionContext router = APIRouter(prefix="/session") - @router.post("/new") def create_new_session( request: CreateSessionRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_db_session)], + context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> InteractiveSessionInfo: - """.""" - interactive_session = InteractiveSessionDB(user_id=user.id, info=request.info) + user = context.user + db_session = context.db_session - session.add(interactive_session) - session.commit() - session.refresh(interactive_session) + interactive_session = InteractiveSessionDB( + user_id=user.id, + info=request.info, + ) + + db_session.add(interactive_session) + db_session.commit() + db_session.refresh(interactive_session) user.active_session_id = interactive_session.id - session.add(user) - session.commit() - session.refresh(interactive_session) + db_session.add(user) + db_session.commit() return interactive_session - @router.get("/get/{session_id}") def get_session( session_id: int, diff --git a/desdeo/api/routers/utils.py b/desdeo/api/routers/utils.py index 8ce72305b..c2f2cef43 100644 --- a/desdeo/api/routers/utils.py +++ b/desdeo/api/routers/utils.py @@ -4,7 +4,7 @@ """ from dataclasses import dataclass -from typing import Annotated +from typing import Annotated, Optional from fastapi import Depends, HTTPException, status from sqlmodel import Session, select @@ -20,8 +20,11 @@ ) from desdeo.api.routers.user_authentication import get_current_user -RequestType = RPMSolveRequest | ENautilusStepRequest +# --------------------------------------------------------------------- +# Request protocol used by session utilities +# --------------------------------------------------------------------- +RequestType = RPMSolveRequest | ENautilusStepRequest def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None: """Gets the desired instance of `InteractiveSessionDB`. @@ -67,7 +70,7 @@ def fetch_interactive_session(user: User, request: RequestType, session: Session return interactive_session -def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB: +def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB | None: """Fetches a user's `ProblemDB` based on the id in the given request. Args: @@ -75,48 +78,36 @@ def fetch_user_problem(user: User, request: RequestType, session: Session) -> Pr request (RequestType): request containing details of the problem to be fetched (`request.problem_id`). session (Session): the database session from which to fetch the problem. - Raises: - HTTPException: a problem with the given id (`request.problem_id`) could not be found (404). - - Returns: - Problem: the instance of `ProblemDB` with the given id. + Returns None if no problem is found. """ - statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id) - problem_db = session.exec(statement).first() - - if problem_db is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." - ) - - return problem_db + if request.problem_id is None: + return None + statement = select(ProblemDB).where( + ProblemDB.user_id == user.id, + ProblemDB.id == request.problem_id, + ) + return session.exec(statement).first() def fetch_parent_state( - user: User, request: RequestType, session: Session, interactive_session: InteractiveSessionDB | None = None + user: User, + request: RequestType, + session: Session, + interactive_session: InteractiveSessionDB | None = None, ) -> StateDB | None: - """Fetches the parent state, if an id is given, or if defined in the given interactive session. - - Determines the appropriate parent `StateDB` instance to associate with a new - state or operation. It first checks whether the `request` explicitly - provides a `parent_state_id`. If so, it attempts to retrieve the - corresponding `StateDB` entry from the database. If no such id is provided, - the function defaults to returning the most recently added state from the - given `interactive_session`, if available. If neither source provides a - parent state, `None` is returned. - + """Fetches the parent state if defined. Args: - user (User): the user for which the parent state is fetched. - request (RequestType): request containing details about the parent state and optionally the - interactive session. - session (Session): the database session from which to fetch the parent state. - interactive_session (InteractiveSessionDB | None, optional): the interactive session containing - information about the parent state. Defaults to None. + user (User): the user for which the parent state is fetched. + request (RequestType): request containing details about the parent state and optionally the + interactive session. + session (Session): the database session from which to fetch the parent state. + interactive_session (InteractiveSessionDB | None, optional): the interactive session containing + information about the parent state. Defaults to None. Raises: - HTTPException: when `request.parent_state_id` is not `None` and a `StateDB` with this id cannot - be found in the given database session. + HTTPException: when `request.parent_state_id` is not `None` and a `StateDB` with this id cannot + be found in the given database session. Returns: StateDB | None: if `request.parent_state_id` is given, returns the corresponding `StateDB`. @@ -124,38 +115,32 @@ def fetch_parent_state( If both `request.parent_state_id` and `interactive_session` are `None`, then returns `None`. """ if request.parent_state_id is None: - # parent state is assumed to be the last sate added to the session. - # if `interactive_session` is None, then parent state is set to None. - parent_state = ( + return ( interactive_session.states[-1] - if (interactive_session is not None and len(interactive_session.states) > 0) + if interactive_session and interactive_session.states else None ) - else: - # request.parent_state_id is not None - statement = select(StateDB).where(StateDB.id == request.parent_state_id) - parent_state = session.exec(statement).first() + statement = select(StateDB).where(StateDB.id == request.parent_state_id) + parent_state = session.exec(statement).first() - # this error is raised because if a parent_state_id is given, it is assumed that the - # user wished to use that state explicitly as the parent. - if parent_state is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" - ) + if parent_state is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Could not find state with id={request.parent_state_id}", + ) return parent_state - @dataclass(frozen=True) class SessionContext: """A generic context to be used in various endpoints.""" user: User db_session: Session - problem_db: ProblemDB - interactive_session: InteractiveSessionDB | None - parent_state: StateDB | None + problem_db: Optional[ProblemDB] = None + interactive_session: Optional[InteractiveSessionDB] = None, + parent_state: Optional[StateDB] = None, def get_session_context( @@ -185,3 +170,11 @@ def get_session_context( interactive_session=interactive_session, parent_state=parent_state, ) + +def get_session_context_base( + user: Annotated[User, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_session)], +) -> SessionContext: + """Gets the current session context. Should be used as a dep.""" + return SessionContext(user=user, db_session=db_session) + diff --git a/desdeo/api/routers/utopia.py b/desdeo/api/routers/utopia.py index 7db3d9005..2abba3e51 100644 --- a/desdeo/api/routers/utopia.py +++ b/desdeo/api/routers/utopia.py @@ -19,6 +19,7 @@ UtopiaResponse, ) from desdeo.api.routers.user_authentication import get_current_user +from desdeo.api.routers.utils import get_session_context_base, SessionContext router = APIRouter(prefix="/utopia") @@ -26,21 +27,21 @@ @router.post("/") def get_utopia_data( request: UtopiaRequest, - user: Annotated[User, Depends(get_current_user)], - session: Annotated[Session, Depends(get_session)], + context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> UtopiaResponse: """Request and receive the Utopia map corresponding to the decision variables sent. Args: request (UtopiaRequest): the set of decision variables and problem for which the utopia forest map is requested for. - user (Annotated[User, Depend(get_current_user)]) the current user - session (Annotated[Session, Depends(get_session)]) the current database session + context (Annotated[SessionContext, Depends(get_session_context_base)]) the current session context Raises: HTTPException: Returns: UtopiaResponse: the map for the forest, to be rendered in frontend """ + session = context.db_session + empty_response = UtopiaResponse(is_utopia=False, map_name="", map_json={}, options={}, description="", years=[]) state = session.exec(select(StateDB).where(StateDB.id == request.solution.state_id)).first() @@ -227,4 +228,4 @@ def treatment_index(part: str) -> str: map_json=json.loads(forest_metadata.map_json), description=map_description, years=forest_metadata.years, - ) + ) \ No newline at end of file diff --git a/desdeo/api/tests/test_routes.py b/desdeo/api/tests/test_routes.py index 62349eed5..15b90762d 100644 --- a/desdeo/api/tests/test_routes.py +++ b/desdeo/api/tests/test_routes.py @@ -649,7 +649,7 @@ def test_nimbus_save_and_delete_save(client: TestClient): assert len(solve_result.saved_solutions) > 0 # 4. Delete save - request: NIMBUSDeleteSaveRequest = NIMBUSDeleteSaveRequest(state_id=2, solution_index=1) + request: NIMBUSDeleteSaveRequest = NIMBUSDeleteSaveRequest(state_id=2, solution_index=1, problem_id=1) response = post_json(client, "/method/nimbus/delete_save", request.model_dump(), access_token) delete_save_result: NIMBUSDeleteSaveResponse = NIMBUSDeleteSaveResponse.model_validate(json.loads(response.content))