From bf595ae20040cfa1d89e65ed97d86e9db12b4a53 Mon Sep 17 00:00:00 2001 From: Giovanni Misitano Date: Mon, 26 Jan 2026 09:26:49 +0200 Subject: [PATCH] Web-API - Linter related changes, like removing deprecated imports and sorting them. - Other syntax changes. --- desdeo/api/routers/emo.py | 56 ++++++-------- desdeo/api/routers/generic.py | 17 ++-- desdeo/api/routers/nimbus.py | 81 +++++++++++--------- desdeo/api/routers/problem.py | 27 +++---- desdeo/api/routers/reference_point_method.py | 11 +-- desdeo/api/routers/session.py | 5 +- desdeo/api/routers/utopia.py | 17 ++-- pyproject.toml | 1 + 8 files changed, 99 insertions(+), 116 deletions(-) diff --git a/desdeo/api/routers/emo.py b/desdeo/api/routers/emo.py index e0f0c11a0..55d590076 100644 --- a/desdeo/api/routers/emo.py +++ b/desdeo/api/routers/emo.py @@ -15,30 +15,24 @@ from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse -from sqlmodel import Session, select +from sqlmodel import select from websockets.asyncio.client import connect from desdeo.api.db import get_session -from desdeo.api.models import InteractiveSessionDB, StateDB +from desdeo.api.models import StateDB from desdeo.api.models.emo import ( EMOFetchRequest, - EMOFetchResponse, EMOIterateRequest, EMOIterateResponse, - EMOSaveRequest, EMOScoreRequest, EMOScoreResponse, - Solution, ) -from desdeo.api.models.problem import ProblemDB -from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState -from desdeo.api.models.user import User -from desdeo.api.routers.user_authentication import get_current_user +from desdeo.api.models.state import EMOIterateState, EMOSCOREState from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor from desdeo.problem import Problem -from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json +from desdeo.tools.score_bands import SCOREBandsConfig, score_json -from .utils import fetch_interactive_session, fetch_user_problem, get_session_context, SessionContext +from .utils import SessionContext, get_session_context router = APIRouter(prefix="/method/emo", tags=["EMO"]) @@ -113,7 +107,6 @@ async def websocket_endpoint( try: while True: data = await websocket.receive_json() - print(data) if "send_to" in data: try: await ws_manager.send_private_message(data, data["send_to"]) @@ -149,6 +142,7 @@ def get_templates() -> list[TemplateOptions]: templates.append(template.template) return templates + @router.post("/iterate") def iterate( request: EMOIterateRequest, @@ -175,8 +169,7 @@ def iterate( templates = request.template_options or get_templates() web_socket_ids = [ - f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" - for template in templates + f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" for template in templates ] client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" @@ -297,7 +290,7 @@ def _spawn_emo_process( session.close() -def _ea_sync( # noqa: PLR0913 +def _ea_sync( problem: Problem, template: TemplateOptions, preference_options: PreferenceOptions | None, @@ -330,7 +323,7 @@ def _ea_sync( # noqa: PLR0913 ) -async def _ea_async( # noqa: PLR0913 +async def _ea_async( problem: Problem, websocket_id: str, client_id: str, @@ -362,6 +355,7 @@ async def _ea_async( # noqa: PLR0913 await ws.send(f'{{"message": "Finished {websocket_id}", "send_to": "{client_id}"}}') results_dict[websocket_id] = results + @router.post("/fetch") async def fetch_results( request: EMOFetchRequest, @@ -392,14 +386,10 @@ async def fetch_results( # Convert objs: dict[str, list[float]] to objs: list[dict[str, float]] raw_objs: dict[str, list[float]] = state.state.objective_values n_solutions = len(next(iter(raw_objs.values()))) - objs: list[dict[str, float]] = [ - {k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions) - ] + objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)] raw_decs: dict[str, list[float]] = state.state.decision_variables - decs: list[dict[str, float]] = [ - {k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions) - ] + decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)] def result_stream(): for i in range(n_solutions): @@ -412,6 +402,7 @@ def result_stream(): return StreamingResponse(result_stream()) + @router.post("/fetch_score") async def fetch_score_bands( request: EMOScoreRequest, @@ -419,7 +410,7 @@ async def fetch_score_bands( ) -> EMOScoreResponse: """Fetches results from a completed EMO method. - Args: request (EMOFetchRequest): The request object containing parameters for fetching + Args: request (EMOFetchRequest): The request object containing parameters for fetching results and of the SCORE bands visualization. context (Annotated[SessionContext, Depends]): The session context. @@ -430,25 +421,22 @@ async def fetch_score_bands( SCOREBandsResult: The results of the SCORE bands visualization. """ # Use context instead of manual fetch - state = context.parent_state + parent_state = context.parent_state db_session = context.db_session problem_db = context.problem_db - if state is None: + if parent_state is None: raise HTTPException(status_code=404, detail="Parent state not found.") - if not isinstance(state.state, EMOIterateState): + if not isinstance(parent_state.state, EMOIterateState): raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.") - if not (state.state.objective_values and state.state.decision_variables): + if not (parent_state.state.objective_values and parent_state.state.decision_variables): raise ValueError("State does not contain results yet.") - if request.config is None: - score_config = SCOREBandsConfig() - else: - score_config = request.config + score_config = SCOREBandsConfig() if request.config is None else request.config - raw_objs: dict[str, list[float]] = state.state.objective_values + raw_objs: dict[str, list[float]] = parent_state.state.objective_values objs = pl.DataFrame(raw_objs) results = score_json( @@ -462,8 +450,8 @@ async def fetch_score_bands( score_db_state = StateDB.create( database_session=db_session, problem_id=problem_db.id, - session_id=state.session_id, - parent_id=state.id, + session_id=parent_state.session_id, + parent_id=parent_state.id, state=score_state, ) diff --git a/desdeo/api/routers/generic.py b/desdeo/api/routers/generic.py index bed234def..3b7afcbda 100644 --- a/desdeo/api/routers/generic.py +++ b/desdeo/api/routers/generic.py @@ -5,30 +5,27 @@ import numpy as np import pandas as pd from fastapi import APIRouter, Depends, HTTPException, status -from sqlmodel import Session, select +from sqlmodel import select -from desdeo.api.db import get_session from desdeo.api.models import ( - InteractiveSessionDB, IntermediateSolutionRequest, IntermediateSolutionState, - ProblemDB, ScoreBandsRequest, ScoreBandsResponse, SolutionReference, StateDB, - User, ) from desdeo.api.models.generic import GenericIntermediateSolutionResponse -from desdeo.api.routers.user_authentication import get_current_user from desdeo.mcdm.nimbus import solve_intermediate_solutions from desdeo.problem import Problem from desdeo.tools import SolverResults from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions -from .utils import get_session_context, SessionContext +from .utils import SessionContext, get_session_context + router = APIRouter(prefix="/method/generic") + @router.post("/intermediate") def solve_intermediate( request: IntermediateSolutionRequest, @@ -42,7 +39,6 @@ def solve_intermediate( context (Annotated[SessionContext, Depends]): The session context. """ db_session = context.db_session - user = context.user problem_db = context.problem_db interactive_session = context.interactive_session parent_state = context.parent_state @@ -64,9 +60,7 @@ def solve_intermediate( reference_states = [] for solution_info in [request.reference_solution_1, request.reference_solution_2]: - solution_state = db_session.exec( - select(StateDB).where(StateDB.id == solution_info.state_id) - ).first() + solution_state = db_session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first() if solution_state is None: raise HTTPException( @@ -165,6 +159,7 @@ def solve_intermediate( ], ) + @router.post("/score-bands-obj-data") def calculate_score_bands_from_objective_data( request: ScoreBandsRequest, diff --git a/desdeo/api/routers/nimbus.py b/desdeo/api/routers/nimbus.py index 63921a639..8d6c61566 100644 --- a/desdeo/api/routers/nimbus.py +++ b/desdeo/api/routers/nimbus.py @@ -6,9 +6,7 @@ from numpy import allclose from sqlmodel import Session, select -from desdeo.api.db import get_session from desdeo.api.models import ( - InteractiveSessionDB, IntermediateSolutionRequest, NIMBUSClassificationRequest, NIMBUSClassificationResponse, @@ -38,19 +36,20 @@ from desdeo.api.models.state import IntermediateSolutionState from desdeo.api.routers.generic import solve_intermediate from desdeo.api.routers.problem import check_solver -from desdeo.api.routers.user_authentication import get_current_user from desdeo.mcdm.nimbus import generate_starting_point, solve_sub_problems from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import get_session_context, SessionContext +from .utils import SessionContext, get_session_context + router = APIRouter(prefix="/method/nimbus") + # helper for collecting solutions def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]: """Filters out the duplicate values of objectives.""" # No solutions or only one solution. There can not be any duplicates. - if len(solutions) < 2: + if len(solutions) < 2: # noqa: PLR2004 return solutions # Get the objective values @@ -106,6 +105,7 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list return filter_duplicates(all_solutions) + @router.post("/solve") def solve_solutions( request: NIMBUSClassificationRequest, @@ -123,8 +123,7 @@ def solve_solutions( # ----------------------------- if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." ) solver = check_solver(problem_db=problem_db) @@ -196,8 +195,7 @@ def initialize( if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." ) solver = check_solver(problem_db=problem_db) @@ -213,8 +211,7 @@ def initialize( if state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"StateDB with index {info.state_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"StateDB with index {info.state_id} could not be found." ) starting_point = state.state.result_objective_values[info.solution_index] @@ -261,10 +258,10 @@ def initialize( all_solutions=all_solutions, ) + @router.post("/save") def save( - request: NIMBUSSaveRequest, - context: Annotated[SessionContext, Depends(get_session_context)] + request: NIMBUSSaveRequest, context: Annotated[SessionContext, Depends(get_session_context)] ) -> NIMBUSSaveResponse: """Save solutions.""" db_session = context.db_session @@ -272,7 +269,6 @@ def save( interactive_session = context.interactive_session parent_state = context.parent_state - # fetch parent state (same logic, but using context) if request.parent_state_id is None: parent_state = ( interactive_session.states[-1] @@ -284,8 +280,7 @@ def save( if parent_state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find state with id={request.parent_state_id}" + status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" ) # Check for duplicate solutions and update names instead of saving duplicates @@ -336,6 +331,7 @@ def save( return NIMBUSSaveResponse(state_id=state.id) + @router.post("/intermediate") def solve_nimbus_intermediate( request: IntermediateSolutionRequest, @@ -363,13 +359,18 @@ def solve_nimbus_intermediate( all_solutions=all_solutions, ) + @router.post("/get-or-initialize") def get_or_initialize( request: NIMBUSInitializationRequest, context: Annotated[SessionContext, Depends(get_session_context)], -) -> NIMBUSInitializationResponse | NIMBUSClassificationResponse | \ - NIMBUSIntermediateSolutionResponse | NIMBUSFinalizeResponse: - +) -> ( + NIMBUSInitializationResponse + | NIMBUSClassificationResponse + | NIMBUSIntermediateSolutionResponse + | NIMBUSFinalizeResponse +): + """Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't.""" db_session = context.db_session user = context.user interactive_session = context.interactive_session @@ -432,7 +433,7 @@ def get_or_initialize( solution_index=solution_index, state_id=origin_state_id, objective_values=latest_state.state.solver_results.optimal_objectives, - variable_values=latest_state.state.solver_results.optimal_variables + variable_values=latest_state.state.solver_results.optimal_variables, ) return NIMBUSFinalizeResponse( @@ -452,12 +453,23 @@ def get_or_initialize( # No relevant state found, initialize a new one return initialize(request, context) + @router.post("/finalize") def finalize_nimbus( - request: NIMBUSFinalizeRequest, - context: Annotated[SessionContext, Depends(get_session_context)] + request: NIMBUSFinalizeRequest, context: Annotated[SessionContext, Depends(get_session_context)] ) -> NIMBUSFinalizeResponse: + """An endpoint for finishing up the nimbus process. + + Args: + request (NIMBUSFinalizeRequest): The request containing the final solution, etc. + context (Annotated[User, get_session_context): The current context. + + Raises: + HTTPException + Returns: + NIMBUSFinalizeResponse: Response containing info on the final solution. + """ db_session = context.db_session user = context.user interactive_session = context.interactive_session @@ -469,16 +481,16 @@ def finalize_nimbus( parent_state = db_session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first() if parent_state is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Could not find state with id={request.parent_state_id}" + status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}" ) # fetch problem - problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)).first() + problem_db = db_session.exec( + select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id) + ).first() if problem_db is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Problem with id={request.problem_id} could not be found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found." ) solution_state_id = request.solution_info.state_id @@ -495,7 +507,7 @@ def finalize_nimbus( final_state = NIMBUSFinalState( solution_origin_state_id=solution_state_id, solution_result_index=solution_index, - solver_results=actual_state.solver_results[solution_index] + solver_results=actual_state.solver_results[solution_index], ) state = StateDB.create( @@ -524,6 +536,7 @@ def finalize_nimbus( all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=db_session), ) + @router.post("/delete_save") def delete_save( request: NIMBUSDeleteSaveRequest, @@ -551,10 +564,7 @@ def delete_save( ).first() if to_be_deleted is None: - raise HTTPException( - detail="Unable to find a saved solution!", - status_code=status.HTTP_404_NOT_FOUND - ) + raise HTTPException(detail="Unable to find a saved solution!", status_code=status.HTTP_404_NOT_FOUND) db_session.delete(to_be_deleted) db_session.commit() @@ -568,10 +578,7 @@ def delete_save( if to_be_deleted is not None: raise HTTPException( - detail="Could not delete the saved solution!", - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + detail="Could not delete the saved solution!", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR ) - return NIMBUSDeleteSaveResponse( - message="Save deleted." - ) \ No newline at end of file + return NIMBUSDeleteSaveResponse(message="Save deleted.") diff --git a/desdeo/api/routers/problem.py b/desdeo/api/routers/problem.py index 1bf2a145f..0438b248d 100644 --- a/desdeo/api/routers/problem.py +++ b/desdeo/api/routers/problem.py @@ -5,9 +5,8 @@ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status from fastapi.responses import JSONResponse -from sqlmodel import Session, select +from sqlmodel import select -from desdeo.api.db import get_session from desdeo.api.models import ( ForestProblemMetaData, ProblemDB, @@ -25,7 +24,8 @@ from desdeo.api.routers.user_authentication import get_current_user from desdeo.problem import Problem from desdeo.tools.utils import available_solvers -from .utils import get_session_context, get_session_context_base, SessionContext + +from .utils import SessionContext, get_session_context, get_session_context_base router = APIRouter(prefix="/problem") @@ -84,6 +84,7 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ """ return user.problems + @router.post("/get") def get_problem( request: ProblemGetRequest, @@ -97,7 +98,6 @@ def get_problem( Raises: HTTPException: could not find a problem with the given id. Returns: ProblemInfo: detailed information on the requested problem. """ - db_session = context.db_session problem_db = context.problem_db # ----------------------------- @@ -111,6 +111,7 @@ def get_problem( return problem_db + @router.post("/add") def add_problem( request: Annotated[Problem, Depends(parse_problem_json)], @@ -140,6 +141,7 @@ def add_problem( return problem_db + @router.post("/add_json") def add_problem_json( json_file: UploadFile, @@ -168,6 +170,7 @@ def add_problem_json( return problem_db + @router.post("/get_metadata") def get_metadata( request: ProblemMetaDataGetRequest, @@ -176,9 +179,7 @@ def get_metadata( """Fetch specific metadata for a specific problem.""" db_session = context.db_session - problem_from_db = db_session.exec( - select(ProblemDB).where(ProblemDB.id == request.problem_id) - ).first() + problem_from_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() if problem_from_db is None: raise HTTPException( @@ -191,17 +192,15 @@ def get_metadata( if problem_metadata is None: return [] - return [ - metadata - for metadata in problem_metadata.all_metadata - if metadata.metadata_type == request.metadata_type - ] + return [metadata for metadata in problem_metadata.all_metadata if metadata.metadata_type == request.metadata_type] + @router.get("/assign/solver", response_model=list[str]) def get_available_solvers() -> list[str]: """Return the list of available solver names.""" return list(available_solvers.keys()) + @router.post("/assign_solver") def select_solver( request: ProblemSelectSolverRequest, @@ -219,9 +218,7 @@ def select_solver( ) # Fetch problem - problem_db = db_session.exec( - select(ProblemDB).where(ProblemDB.id == request.problem_id) - ).first() + problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first() if problem_db is None: raise HTTPException( diff --git a/desdeo/api/routers/reference_point_method.py b/desdeo/api/routers/reference_point_method.py index 68d2d49b1..7a515f095 100644 --- a/desdeo/api/routers/reference_point_method.py +++ b/desdeo/api/routers/reference_point_method.py @@ -3,28 +3,23 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status -from sqlmodel import Session -from desdeo.api.db import get_session from desdeo.api.models import ( - InteractiveSessionDB, PreferenceDB, - ProblemDB, RPMSolveRequest, RPMState, StateDB, - User, ) from desdeo.api.routers.problem import check_solver -from desdeo.api.routers.user_authentication import get_current_user from desdeo.mcdm import rpm_solve_solutions from desdeo.problem import Problem from desdeo.tools import SolverResults -from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem, get_session_context, SessionContext +from .utils import SessionContext, get_session_context router = APIRouter(prefix="/method/rpm") + @router.post("/solve") def solve_solutions( request: RPMSolveRequest, @@ -86,4 +81,4 @@ def solve_solutions( db_session.commit() db_session.refresh(state) - return rpm_state \ No newline at end of file + return rpm_state diff --git a/desdeo/api/routers/session.py b/desdeo/api/routers/session.py index 2421c0444..fb1c3ba33 100644 --- a/desdeo/api/routers/session.py +++ b/desdeo/api/routers/session.py @@ -14,15 +14,17 @@ User, ) from desdeo.api.routers.user_authentication import get_current_user -from desdeo.api.routers.utils import fetch_interactive_session, get_session_context_base, SessionContext +from desdeo.api.routers.utils import SessionContext, fetch_interactive_session, get_session_context_base router = APIRouter(prefix="/session") + @router.post("/new") def create_new_session( request: CreateSessionRequest, context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> InteractiveSessionInfo: + """Creates a new interactive session.""" user = context.user db_session = context.db_session @@ -42,6 +44,7 @@ def create_new_session( return interactive_session + @router.get("/get/{session_id}") def get_session( session_id: int, diff --git a/desdeo/api/routers/utopia.py b/desdeo/api/routers/utopia.py index 2abba3e51..9792c21a9 100644 --- a/desdeo/api/routers/utopia.py +++ b/desdeo/api/routers/utopia.py @@ -4,9 +4,8 @@ from typing import Annotated from fastapi import APIRouter, Depends -from sqlmodel import Session, select +from sqlmodel import select -from desdeo.api.db import get_session from desdeo.api.models import ( ForestProblemMetaData, NIMBUSFinalState, @@ -14,18 +13,16 @@ NIMBUSSaveState, ProblemMetaDataDB, StateDB, - User, UtopiaRequest, UtopiaResponse, ) -from desdeo.api.routers.user_authentication import get_current_user -from desdeo.api.routers.utils import get_session_context_base, SessionContext +from desdeo.api.routers.utils import SessionContext, get_session_context_base router = APIRouter(prefix="/utopia") @router.post("/") -def get_utopia_data( +def get_utopia_data( # noqa: C901 request: UtopiaRequest, context: Annotated[SessionContext, Depends(get_session_context_base)], ) -> UtopiaResponse: @@ -33,8 +30,8 @@ def get_utopia_data( Args: request (UtopiaRequest): the set of decision variables and problem for which the utopia forest map is requested - for. - context (Annotated[SessionContext, Depends(get_session_context_base)]) the current session context + for. + context (Annotated[SessionContext, Depends(get_session_context_base)]): the current session context Raises: HTTPException: Returns: @@ -106,7 +103,7 @@ def treatment_index(part: str) -> str: # The dict keys get converted to ints to strings when it's loaded from database try: treatments = forest_metadata.schedule_dict[key][str(decision_variables[key].index(1))] - except ValueError as e: + except ValueError: # if the optimization didn't choose any decision alternative, it's safe to assume # that nothing is being done at that forest stand treatments = forest_metadata.schedule_dict[key]["0"] @@ -228,4 +225,4 @@ def treatment_index(part: str) -> str: map_json=json.loads(forest_metadata.map_json), description=map_description, years=forest_metadata.years, - ) \ No newline at end of file + ) diff --git a/pyproject.toml b/pyproject.toml index 08edec38f..84bd2009b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,6 +156,7 @@ lint.ignore = [ "COM812", # Enforcing trailing commas is too annoying. "PLR0911", # "too many return statements (>6)" "PLR0912", # "too many branches" + "PLR0913", # "too many function args" "PLR0915", # "too many statements (>50)" "TRY003", # allow long error messages ]