diff --git a/desdeo/api/routers/emo.py b/desdeo/api/routers/emo.py index 72d7c9b80..03275dd25 100644 --- a/desdeo/api/routers/emo.py +++ b/desdeo/api/routers/emo.py @@ -15,30 +15,24 @@ from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse -from sqlmodel import Session, select +from sqlmodel import select from websockets.asyncio.client import connect from desdeo.api.db import get_session -from desdeo.api.models import InteractiveSessionDB, StateDB +from desdeo.api.models import StateDB from desdeo.api.models.emo import ( EMOFetchRequest, - EMOFetchResponse, EMOIterateRequest, EMOIterateResponse, - EMOSaveRequest, EMOScoreRequest, EMOScoreResponse, - Solution, ) -from desdeo.api.models.problem import ProblemDB -from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState -from desdeo.api.models.user import User -from desdeo.api.routers.user_authentication import get_current_user +from desdeo.api.models.state import EMOIterateState, EMOSCOREState from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor from desdeo.problem import Problem -from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json +from desdeo.tools.score_bands import SCOREBandsConfig, score_json -from .utils import fetch_interactive_session, fetch_user_problem, get_session_context, SessionContext +from .utils import SessionContext, get_session_context router = APIRouter(prefix="/method/emo", tags=["EMO"]) @@ -113,7 +107,6 @@ async def websocket_endpoint( try: while True: data = await websocket.receive_json() - # print(data) if "send_to" in data: try: await ws_manager.send_private_message(data, data["send_to"]) @@ -149,6 +142,7 @@ def get_templates() -> list[TemplateOptions]: templates.append(template.template) return templates + @router.post("/iterate") def iterate( request: EMOIterateRequest, @@ -179,8 +173,7 @@ def iterate( templates = request.template_options or get_templates() web_socket_ids = [ - f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" - for template in templates + f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" for template in templates ] client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" @@ -301,7 +294,7 @@ def _spawn_emo_process( # noqa: PLR0913 session.close() -def _ea_sync( # noqa: PLR0913 +def _ea_sync( problem: Problem, template: TemplateOptions, preference_options: PreferenceOptions | None, @@ -334,7 +327,7 @@ def _ea_sync( # noqa: PLR0913 ) -async def _ea_async( # noqa: PLR0913 +async def _ea_async( problem: Problem, websocket_id: str, client_id: str, @@ -366,6 +359,7 @@ async def _ea_async( # noqa: PLR0913 await ws.send(f'{{"message": "Finished {websocket_id}", "send_to": "{client_id}"}}') results_dict[websocket_id] = results + @router.post("/fetch") async def fetch_results( request: EMOFetchRequest, @@ -396,14 +390,10 @@ async def fetch_results( # Convert objs: dict[str, list[float]] to objs: list[dict[str, float]] raw_objs: dict[str, list[float]] = state.state.objective_values n_solutions = len(next(iter(raw_objs.values()))) - objs: list[dict[str, float]] = [ - {k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions) - ] + objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)] raw_decs: dict[str, list[float]] = state.state.decision_variables - decs: list[dict[str, float]] = [ - {k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions) - ] + decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)] def result_stream(): for i in range(n_solutions): @@ -416,6 +406,7 @@ def result_stream(): return StreamingResponse(result_stream()) + @router.post("/fetch_score") async def fetch_score_bands( request: EMOScoreRequest, @@ -434,22 +425,22 @@ async def fetch_score_bands( SCOREBandsResult: The results of the SCORE bands visualization. """ # Use context instead of manual fetch - state = context.parent_state + parent_state = context.parent_state db_session = context.db_session problem_db = context.problem_db - if state is None: + if parent_state is None: raise HTTPException(status_code=404, detail="Parent state not found.") - if not isinstance(state.state, EMOIterateState): + if not isinstance(parent_state.state, EMOIterateState): raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.") - if not (state.state.objective_values and state.state.decision_variables): + if not (parent_state.state.objective_values and parent_state.state.decision_variables): raise ValueError("State does not contain results yet.") score_config = SCOREBandsConfig() if request.config is None else request.config - raw_objs: dict[str, list[float]] = state.state.objective_values + raw_objs: dict[str, list[float]] = parent_state.state.objective_values objs = pl.DataFrame(raw_objs) results = score_json( @@ -463,8 +454,8 @@ async def fetch_score_bands( score_db_state = StateDB.create( database_session=db_session, problem_id=problem_db.id, - session_id=state.session_id, - parent_id=state.id, + session_id=parent_state.session_id, + parent_id=parent_state.id, state=score_state, ) diff --git a/desdeo/api/routers/generic.py b/desdeo/api/routers/generic.py index ac4d76f7c..3b7afcbda 100644 --- a/desdeo/api/routers/generic.py +++ b/desdeo/api/routers/generic.py @@ -5,30 +5,27 @@ import numpy as np import pandas as pd from fastapi import APIRouter, Depends, HTTPException, status -from sqlmodel import Session, select +from sqlmodel import select -from desdeo.api.db import get_session from desdeo.api.models import ( - InteractiveSessionDB, IntermediateSolutionRequest, IntermediateSolutionState, - ProblemDB, ScoreBandsRequest, ScoreBandsResponse, SolutionReference, StateDB, - User, ) from desdeo.api.models.generic import GenericIntermediateSolutionResponse -from desdeo.api.routers.user_authentication import get_current_user from desdeo.mcdm.nimbus import solve_intermediate_solutions from desdeo.problem import Problem from desdeo.tools import SolverResults from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions -from .utils import get_session_context, SessionContext +from .utils import SessionContext, get_session_context + router = APIRouter(prefix="/method/generic") + @router.post("/intermediate") def solve_intermediate( request: IntermediateSolutionRequest, @@ -42,7 +39,6 @@ def solve_intermediate( context (Annotated[SessionContext, Depends]): The session context. """ db_session = context.db_session - user = context.user # noqa: F841 problem_db = context.problem_db interactive_session = context.interactive_session parent_state = context.parent_state @@ -64,9 +60,7 @@ def solve_intermediate( reference_states = [] for solution_info in [request.reference_solution_1, request.reference_solution_2]: - solution_state = db_session.exec( - select(StateDB).where(StateDB.id == solution_info.state_id) - ).first() + solution_state = db_session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first() if solution_state is None: raise HTTPException( @@ -165,6 +159,7 @@ def solve_intermediate( ], ) + @router.post("/score-bands-obj-data") def calculate_score_bands_from_objective_data( request: ScoreBandsRequest, diff --git a/desdeo/api/routers/nimbus.py b/desdeo/api/routers/nimbus.py index 10aa2f4a3..564ca9fba 100644 --- a/desdeo/api/routers/nimbus.py +++ b/desdeo/api/routers/nimbus.py @@ -6,9 +6,7 @@ from numpy import allclose from sqlmodel import Session, select -from desdeo.api.db import get_session from desdeo.api.models import ( - InteractiveSessionDB, IntermediateSolutionRequest, NIMBUSClassificationRequest, NIMBUSClassificationResponse, @@ -38,14 +36,15 @@ from desdeo.api.models.state import IntermediateSolutionState from desdeo.api.routers.generic import solve_intermediate from desdeo.api.routers.problem import check_solver -from desdeo.api.routers.user_authentication import get_current_user from desdeo.mcdm.nimbus import generate_starting_point, solve_sub_problems from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import get_session_context, SessionContext +from .utils import SessionContext, get_session_context + router = APIRouter(prefix="/method/nimbus") + # helper for collecting solutions def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]: """Filters out the duplicate values of objectives.""" @@ -106,6 +105,7 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list return filter_duplicates(all_solutions) + @router.post("/solve") def solve_solutions( request: NIMBUSClassificationRequest, @@ -123,8 +123,7 @@ def solve_solutions( # ----------------------------- if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." ) solver = check_solver(problem_db=problem_db) @@ -196,8 +195,7 @@ def initialize( if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." ) solver = check_solver(problem_db=problem_db) @@ -213,8 +211,7 @@ def initialize( if state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"StateDB with index {info.state_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"StateDB with index {info.state_id} could not be found." ) starting_point = state.state.result_objective_values[info.solution_index] @@ -261,10 +258,10 @@ def initialize( all_solutions=all_solutions, ) + @router.post("/save") def save( - request: NIMBUSSaveRequest, - context: Annotated[SessionContext, Depends(get_session_context)] + request: NIMBUSSaveRequest, context: Annotated[SessionContext, Depends(get_session_context)] ) -> NIMBUSSaveResponse: """Save solutions.""" db_session = context.db_session @@ -272,7 +269,6 @@ def save( interactive_session = context.interactive_session parent_state = context.parent_state - # fetch parent state (same logic, but using context) if request.parent_state_id is None: parent_state = ( interactive_session.states[-1] @@ -284,8 +280,7 @@ def save( if parent_state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find state with id={request.parent_state_id}" + status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" ) # Check for duplicate solutions and update names instead of saving duplicates @@ -336,6 +331,7 @@ def save( return NIMBUSSaveResponse(state_id=state.id) + @router.post("/intermediate") def solve_nimbus_intermediate( request: IntermediateSolutionRequest, @@ -365,12 +361,17 @@ def solve_nimbus_intermediate( all_solutions=all_solutions, ) + @router.post("/get-or-initialize") def get_or_initialize( request: NIMBUSInitializationRequest, context: Annotated[SessionContext, Depends(get_session_context)], -) -> NIMBUSInitializationResponse | NIMBUSClassificationResponse | \ - NIMBUSIntermediateSolutionResponse | NIMBUSFinalizeResponse: +) -> ( + NIMBUSInitializationResponse + | NIMBUSClassificationResponse + | NIMBUSIntermediateSolutionResponse + | NIMBUSFinalizeResponse +): """Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't.""" db_session = context.db_session user = context.user @@ -435,7 +436,7 @@ def get_or_initialize( solution_index=solution_index, state_id=origin_state_id, objective_values=latest_state.state.solver_results.optimal_objectives, - variable_values=latest_state.state.solver_results.optimal_variables + variable_values=latest_state.state.solver_results.optimal_variables, ) return NIMBUSFinalizeResponse( @@ -455,16 +456,16 @@ def get_or_initialize( # No relevant state found, initialize a new one return initialize(request, context) + @router.post("/finalize") def finalize_nimbus( - request: NIMBUSFinalizeRequest, - context: Annotated[SessionContext, Depends(get_session_context)] + request: NIMBUSFinalizeRequest, context: Annotated[SessionContext, Depends(get_session_context)] ) -> NIMBUSFinalizeResponse: """An endpoint for finishing up the nimbus process. Args: request (NIMBUSFinalizeRequest): The request containing the final solution, etc. - context (Annotated[SessionContext, Depends): The session context. + context (Annotated[User, get_session_context): The current context. Raises: HTTPException @@ -492,7 +493,7 @@ def finalize_nimbus( final_state = NIMBUSFinalState( solution_origin_state_id=solution_state_id, solution_result_index=solution_index, - solver_results=actual_state.solver_results[solution_index] + solver_results=actual_state.solver_results[solution_index], ) state = StateDB.create( @@ -521,6 +522,7 @@ def finalize_nimbus( all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=db_session), ) + @router.post("/delete_save") def delete_save( request: NIMBUSDeleteSaveRequest, @@ -548,10 +550,7 @@ def delete_save( ).first() if to_be_deleted is None: - raise HTTPException( - detail="Unable to find a saved solution!", - status_code=status.HTTP_404_NOT_FOUND - ) + raise HTTPException(detail="Unable to find a saved solution!", status_code=status.HTTP_404_NOT_FOUND) db_session.delete(to_be_deleted) db_session.commit() @@ -565,10 +564,6 @@ def delete_save( if to_be_deleted is not None: raise HTTPException( - detail="Could not delete the saved solution!", - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + detail="Could not delete the saved solution!", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR ) - - return NIMBUSDeleteSaveResponse( - message="Save deleted." - ) + return NIMBUSDeleteSaveResponse(message="Save deleted.") diff --git a/desdeo/api/routers/problem.py b/desdeo/api/routers/problem.py index c3d17e698..9a85081d6 100644 --- a/desdeo/api/routers/problem.py +++ b/desdeo/api/routers/problem.py @@ -5,9 +5,8 @@ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status from fastapi.responses import JSONResponse -from sqlmodel import Session, select +from sqlmodel import select -from desdeo.api.db import get_session from desdeo.api.models import ( ForestProblemMetaData, ProblemDB, @@ -25,7 +24,8 @@ from desdeo.api.routers.user_authentication import get_current_user from desdeo.problem import Problem from desdeo.tools.utils import available_solvers -from .utils import get_session_context, get_session_context_base, SessionContext + +from .utils import SessionContext, get_session_context, get_session_context_base router = APIRouter(prefix="/problem") @@ -83,6 +83,7 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ """ return user.problems + @router.post("/get") def get_problem( request: ProblemGetRequest, @@ -100,7 +101,6 @@ def get_problem( Returns: ProblemInfo: detailed information on the requested problem. """ - # db_session = context.db_session problem_db = context.problem_db # Ensure problem exists @@ -112,6 +112,7 @@ def get_problem( return problem_db + @router.post("/add") def add_problem( request: Annotated[Problem, Depends(parse_problem_json)], @@ -155,6 +156,7 @@ def add_problem( return problem_db + @router.post("/add_json") def add_problem_json( json_file: UploadFile, @@ -195,6 +197,7 @@ def add_problem_json( return problem_db + @router.post("/get_metadata") def get_metadata( request: ProblemMetaDataGetRequest, @@ -217,9 +220,7 @@ def get_metadata( """ db_session = context.db_session - problem_from_db = db_session.exec( - select(ProblemDB).where(ProblemDB.id == request.problem_id) - ).first() + problem_from_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() if problem_from_db is None: raise HTTPException( @@ -233,17 +234,14 @@ def get_metadata( # no metadata define for the problem return [] # metadata is defined, try to find matching types based on request - return [ - metadata - for metadata in problem_metadata.all_metadata - if metadata.metadata_type == request.metadata_type - ] + return [metadata for metadata in problem_metadata.all_metadata if metadata.metadata_type == request.metadata_type] @router.get("/assign/solver", response_model=list[str]) def get_available_solvers() -> list[str]: """Return the list of available solver names.""" return list(available_solvers.keys()) + @router.post("/assign_solver") def select_solver( request: ProblemSelectSolverRequest, @@ -272,9 +270,7 @@ def select_solver( ) # Fetch problem - problem_db = db_session.exec( - select(ProblemDB).where(ProblemDB.id == request.problem_id) - ).first() + problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() if problem_db is None: raise HTTPException( diff --git a/desdeo/api/routers/reference_point_method.py b/desdeo/api/routers/reference_point_method.py index 644034031..d1fec4c50 100644 --- a/desdeo/api/routers/reference_point_method.py +++ b/desdeo/api/routers/reference_point_method.py @@ -3,28 +3,23 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status -from sqlmodel import Session -from desdeo.api.db import get_session from desdeo.api.models import ( - InteractiveSessionDB, PreferenceDB, - ProblemDB, RPMSolveRequest, RPMState, StateDB, - User, ) from desdeo.api.routers.problem import check_solver -from desdeo.api.routers.user_authentication import get_current_user from desdeo.mcdm import rpm_solve_solutions from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem, get_session_context, SessionContext +from .utils import SessionContext, get_session_context router = APIRouter(prefix="/method/rpm") + @router.post("/solve") def solve_solutions( request: RPMSolveRequest, @@ -97,4 +92,4 @@ def solve_solutions( db_session.commit() db_session.refresh(state) - return rpm_state \ No newline at end of file + return rpm_state diff --git a/desdeo/api/routers/session.py b/desdeo/api/routers/session.py index 2421c0444..fb1c3ba33 100644 --- a/desdeo/api/routers/session.py +++ b/desdeo/api/routers/session.py @@ -14,15 +14,17 @@ User, ) from desdeo.api.routers.user_authentication import get_current_user -from desdeo.api.routers.utils import fetch_interactive_session, get_session_context_base, SessionContext +from desdeo.api.routers.utils import SessionContext, fetch_interactive_session, get_session_context_base router = APIRouter(prefix="/session") + @router.post("/new") def create_new_session( request: CreateSessionRequest, context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> InteractiveSessionInfo: + """Creates a new interactive session.""" user = context.user db_session = context.db_session @@ -42,6 +44,7 @@ def create_new_session( return interactive_session + @router.get("/get/{session_id}") def get_session( session_id: int, diff --git a/desdeo/api/routers/utopia.py b/desdeo/api/routers/utopia.py index a474730ec..a1fb57753 100644 --- a/desdeo/api/routers/utopia.py +++ b/desdeo/api/routers/utopia.py @@ -4,9 +4,8 @@ from typing import Annotated from fastapi import APIRouter, Depends -from sqlmodel import Session, select +from sqlmodel import select -from desdeo.api.db import get_session from desdeo.api.models import ( ForestProblemMetaData, NIMBUSFinalState, @@ -14,7 +13,6 @@ NIMBUSSaveState, ProblemMetaDataDB, StateDB, - User, UtopiaRequest, UtopiaResponse, ) @@ -23,7 +21,6 @@ router = APIRouter(prefix="/utopia") - @router.post("/") def get_utopia_data( # noqa: C901 request: UtopiaRequest, @@ -32,9 +29,9 @@ def get_utopia_data( # noqa: C901 """Request and receive the Utopia map corresponding to the decision variables sent. Args: - request (UtopiaRequest): the set of decision variables and problem for which - the utopia forest map is requested for. - context (Annotated[SessionContext, Depends]): The session context. + request (UtopiaRequest): the set of decision variables and problem for which the utopia forest map is requested + for. + context (Annotated[SessionContext, Depends(get_session_context)]): the current session context Raises: HTTPException: diff --git a/pyproject.toml b/pyproject.toml index fb1ca9b78..417b08e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ lint.ignore = [ "COM812", # Enforcing trailing commas is too annoying. "PLR0911", # "too many return statements (>6)" "PLR0912", # "too many branches" + "PLR0913", # "too many function args" "PLR0915", # "too many statements (>50)" "TRY003", # allow long error messages ]