Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions desdeo/api/routers/emo.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ async def websocket_endpoint(
try:
while True:
data = await websocket.receive_json()
print(data)
# print(data)
if "send_to" in data:
try:
await ws_manager.send_private_message(data, data["send_to"])
Expand Down Expand Up @@ -158,6 +158,10 @@ def iterate(

Args: request (EMOIterateRequest): The request object containing parameters for fetching results.
context (Annotated[SessionContext, Depends]): The session context.

Raises: HTTPException: If the request is invalid or the EMO method fails.
Returns: IterateResponse: A response object containing a list of IDs to be used for websocket communication.
Also contains the StateDB id where the results will be stored.
"""
# 1) Get context objects
db_session = context.db_session
Expand Down Expand Up @@ -222,7 +226,7 @@ def iterate(
return EMOIterateResponse(method_ids=web_socket_ids, client_id=client_id, state_id=state_id)


def _spawn_emo_process(
def _spawn_emo_process( # noqa: PLR0913
problem: Problem,
templates: list[TemplateOptions],
preference_options: PreferenceOptions | None,
Expand Down Expand Up @@ -419,7 +423,7 @@ async def fetch_score_bands(
) -> EMOScoreResponse:
"""Fetches results from a completed EMO method.

Args: request (EMOFetchRequest): The request object containing parameters for fetching
Args: request (EMOFetchRequest): The request object containing parameters for fetching
results and of the SCORE bands visualization.
context (Annotated[SessionContext, Depends]): The session context.

Expand All @@ -443,10 +447,7 @@ async def fetch_score_bands(
if not (state.state.objective_values and state.state.decision_variables):
raise ValueError("State does not contain results yet.")

if request.config is None:
score_config = SCOREBandsConfig()
else:
score_config = request.config
score_config = SCOREBandsConfig() if request.config is None else request.config

raw_objs: dict[str, list[float]] = state.state.objective_values
objs = pl.DataFrame(raw_objs)
Expand Down
2 changes: 1 addition & 1 deletion desdeo/api/routers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def solve_intermediate(
context (Annotated[SessionContext, Depends]): The session context.
"""
db_session = context.db_session
user = context.user
user = context.user # noqa: F841
problem_db = context.problem_db
interactive_session = context.interactive_session
parent_state = context.parent_state
Expand Down
41 changes: 19 additions & 22 deletions desdeo/api/routers/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]:
"""Filters out the duplicate values of objectives."""
# No solutions or only one solution. There can not be any duplicates.
if len(solutions) < 2:
if len(solutions) < 2: # noqa: PLR2004
return solutions

# Get the objective values
Expand Down Expand Up @@ -351,7 +351,9 @@ def solve_nimbus_intermediate(
# Forward to generic endpoint
intermediate_response = solve_intermediate(request, context)

# Get saved solutions for this user and problem
saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
# Get all solutions including the newly generated intermediate ones
all_solutions = collect_all_solutions(user, request.problem_id, db_session)

return NIMBUSIntermediateSolutionResponse(
Expand All @@ -369,7 +371,7 @@ def get_or_initialize(
context: Annotated[SessionContext, Depends(get_session_context)],
) -> NIMBUSInitializationResponse | NIMBUSClassificationResponse | \
NIMBUSIntermediateSolutionResponse | NIMBUSFinalizeResponse:

"""Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't."""
db_session = context.db_session
user = context.user
interactive_session = context.interactive_session
Expand All @@ -385,6 +387,7 @@ def get_or_initialize(
)
states = db_session.exec(statement).all()

# Find the latest relevant state (NIMBUS classification, initialization, or intermediate with NIMBUS context)
latest_state = None
for state in states:
if isinstance(state.state, (NIMBUSClassificationState | NIMBUSInitializationState | NIMBUSFinalState)) or (
Expand Down Expand Up @@ -441,7 +444,7 @@ def get_or_initialize(
saved_solutions=saved_solutions,
all_solutions=all_solutions,
)

# NIMBUSInitializationState
return NIMBUSInitializationResponse(
state_id=latest_state.id,
current_solutions=current_solutions,
Expand All @@ -457,29 +460,23 @@ def finalize_nimbus(
request: NIMBUSFinalizeRequest,
context: Annotated[SessionContext, Depends(get_session_context)]
) -> NIMBUSFinalizeResponse:
"""An endpoint for finishing up the nimbus process.

Args:
request (NIMBUSFinalizeRequest): The request containing the final solution, etc.
context (Annotated[SessionContext, Depends): The session context.

Raises:
HTTPException

Returns:
NIMBUSFinalizeResponse: Response containing info on the final solution.
"""
db_session = context.db_session
user = context.user
interactive_session = context.interactive_session
parent_state = context.parent_state

if request.parent_state_id is None:
parent_state = None
else:
parent_state = db_session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first()
if parent_state is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Could not find state with id={request.parent_state_id}"
)

# fetch problem
problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)).first()
if problem_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Problem with id={request.problem_id} could not be found."
)
problem_db = context.problem_db

solution_state_id = request.solution_info.state_id
solution_index = request.solution_info.solution_index
Expand Down Expand Up @@ -574,4 +571,4 @@ def delete_save(

return NIMBUSDeleteSaveResponse(
message="Save deleted."
)
)
86 changes: 70 additions & 16 deletions desdeo/api/routers/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

router = APIRouter(prefix="/problem")


def check_solver(problem_db: ProblemDB):
"""Check if a preferred solver is set in the metadata.

Expand Down Expand Up @@ -91,18 +90,20 @@ def get_problem(
) -> ProblemInfo:
"""Get the model of a specific problem.

Args: request (ProblemGetRequest): the request containing the problem's id `problem_id`.
user (Annotated[User, Depends): the current user.
session (Annotated[Session, Depends): the database session.
Raises: HTTPException: could not find a problem with the given id.
Returns: ProblemInfo: detailed information on the requested problem.
Args:
request (ProblemGetRequest): the request containing the problem's id `problem_id`.
context (Annotated[SessionContext, Depends): the session context.

Raises:
HTTPException: could not find a problem with the given id.

Returns:
ProblemInfo: detailed information on the requested problem.
"""
db_session = context.db_session
# db_session = context.db_session
problem_db = context.problem_db

# -----------------------------
# Ensure problem exists
# -----------------------------
if problem_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -116,7 +117,21 @@ def add_problem(
request: Annotated[Problem, Depends(parse_problem_json)],
context: Annotated[SessionContext, Depends(get_session_context_base)],
) -> ProblemInfo:
"""Add a newly defined problem to the database."""
"""Add a newly defined problem to the database.

Args:
request (Problem): the JSON representation of the problem.
context (Annotated[SessionContext, Depends): the session context.

Note:
Users with the role 'guest' may not add new problems.

Raises:
HTTPException: when any issue with defining the problem arises.

Returns:
ProblemInfo: the information about the problem added.
"""
user = context.user
db_session = context.db_session

Expand Down Expand Up @@ -145,7 +160,19 @@ def add_problem_json(
json_file: UploadFile,
context: Annotated[SessionContext, Depends(get_session_context_base)],
) -> ProblemInfo:
"""Adds a problem to the database based on its JSON definition."""
"""Adds a problem to the database based on its JSON definition.

Args:
json_file (UploadFile): a file in JSON format describing the problem.
context (Annotated[SessionContext, Depends): the session context.

Raises:
HTTPException: if the provided `json_file` is empty.
HTTPException: if the content in the provided `json_file` is not in JSON format.__annotations__

Returns:
ProblemInfo: a description of the added problem.
"""
user = context.user
db_session = context.db_session

Expand Down Expand Up @@ -173,7 +200,21 @@ def get_metadata(
request: ProblemMetaDataGetRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
"""Fetch specific metadata for a specific problem."""
"""Fetch specific metadata for a specific problem.

Fetch specific metadata for a specific problem. See all the possible
metadata types from DESDEO/desdeo/api/models/problem.py Problem Metadata
section.

Args:
request (MetaDataGetRequest): the requested metadata type.
context (Annotated[SessionContext, Depends]): the session context.

Returns:
list[ForestProblemMetadata | RepresentativeNonDominatedSolutions]: list containing all the metadata
defined for the problem with the requested metadata type. If no match is found,
returns an empty list.
"""
db_session = context.db_session

problem_from_db = db_session.exec(
Expand All @@ -189,8 +230,9 @@ def get_metadata(
problem_metadata = problem_from_db.problem_metadata

if problem_metadata is None:
# no metadata define for the problem
return []

# metadata is defined, try to find matching types based on request
return [
metadata
for metadata in problem_metadata.all_metadata
Expand All @@ -207,7 +249,18 @@ def select_solver(
request: ProblemSelectSolverRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
) -> JSONResponse:
"""Assign a specific solver for a problem."""
"""Assign a specific solver for a problem.

Args:
request: ProblemSelectSolverRequest: The request containing problem id and string representation of the solver
context: Annotated[SessionContext, Depends(get_session)]: The session context.

Raises:
HTTPException: Unknown solver, unauthorized user

Returns:
JSONResponse: A simple confirmation.
"""
db_session = context.db_session
user = context.user

Expand All @@ -229,16 +282,17 @@ def select_solver(
status_code=status.HTTP_404_NOT_FOUND,
)

# Authorization
# Auth the user
if user.id != problem_db.user_id:
raise HTTPException(
detail="Unauthorized user!",
status_code=status.HTTP_401_UNAUTHORIZED,
)

# Ensure metadata exists
# All good, get on with it.
problem_metadata = problem_db.problem_metadata
if problem_metadata is None:
# There's no metadata for this problem! Create some.
problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db)
db_session.add(problem_metadata)
db_session.commit()
Expand Down
17 changes: 14 additions & 3 deletions desdeo/api/routers/reference_point_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,17 @@ def solve_solutions(
request: RPMSolveRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
) -> RPMState:
"""Runs an iteration of the reference point method."""
"""Runs an iteration of the reference point method.

Args:
request (RPMSolveRequest): a request with the needed information to run the method.
user (Annotated[User, Depends): the current user.
context (Annotated[SessionContext, Depends): the current session context.

Returns:
RPMState: a state with information on the results of iterating the reference point method
once.
"""
user = context.user
db_session = context.db_session
problem_db = context.problem_db
Expand All @@ -55,6 +65,7 @@ def solve_solutions(
request.solver_options,
)

# create DB preference
preference_db = PreferenceDB(
user_id=user.id,
problem_id=problem_db.id,
Expand All @@ -65,15 +76,15 @@ def solve_solutions(
db_session.commit()
db_session.refresh(preference_db)

# create RPM state (API model)
# create state and add to DB
rpm_state = RPMState(
scalarization_options=request.scalarization_options,
solver=request.solver,
solver_options=request.solver_options,
solver_results=solver_results,
)

# create DB state
# create DB state and add it to the DB
state = StateDB(
problem_id=problem_db.id,
preference_id=preference_db.id,
Expand Down
Loading
Loading