Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 20 additions & 29 deletions desdeo/api/routers/emo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,24 @@
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from sqlmodel import Session, select
from sqlmodel import select
from websockets.asyncio.client import connect

from desdeo.api.db import get_session
from desdeo.api.models import InteractiveSessionDB, StateDB
from desdeo.api.models import StateDB
from desdeo.api.models.emo import (
EMOFetchRequest,
EMOFetchResponse,
EMOIterateRequest,
EMOIterateResponse,
EMOSaveRequest,
EMOScoreRequest,
EMOScoreResponse,
Solution,
)
from desdeo.api.models.problem import ProblemDB
from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState
from desdeo.api.models.user import User
from desdeo.api.routers.user_authentication import get_current_user
from desdeo.api.models.state import EMOIterateState, EMOSCOREState
from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor
from desdeo.problem import Problem
from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json
from desdeo.tools.score_bands import SCOREBandsConfig, score_json

from .utils import fetch_interactive_session, fetch_user_problem, get_session_context, SessionContext
from .utils import SessionContext, get_session_context

router = APIRouter(prefix="/method/emo", tags=["EMO"])

Expand Down Expand Up @@ -113,7 +107,6 @@ async def websocket_endpoint(
try:
while True:
data = await websocket.receive_json()
# print(data)
if "send_to" in data:
try:
await ws_manager.send_private_message(data, data["send_to"])
Expand Down Expand Up @@ -149,6 +142,7 @@ def get_templates() -> list[TemplateOptions]:
templates.append(template.template)
return templates


@router.post("/iterate")
def iterate(
request: EMOIterateRequest,
Expand Down Expand Up @@ -179,8 +173,7 @@ def iterate(
templates = request.template_options or get_templates()

web_socket_ids = [
f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
for template in templates
f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" for template in templates
]

client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
Expand Down Expand Up @@ -301,7 +294,7 @@ def _spawn_emo_process( # noqa: PLR0913
session.close()


def _ea_sync( # noqa: PLR0913
def _ea_sync(
problem: Problem,
template: TemplateOptions,
preference_options: PreferenceOptions | None,
Expand Down Expand Up @@ -334,7 +327,7 @@ def _ea_sync( # noqa: PLR0913
)


async def _ea_async( # noqa: PLR0913
async def _ea_async(
problem: Problem,
websocket_id: str,
client_id: str,
Expand Down Expand Up @@ -366,6 +359,7 @@ async def _ea_async( # noqa: PLR0913
await ws.send(f'{{"message": "Finished {websocket_id}", "send_to": "{client_id}"}}')
results_dict[websocket_id] = results


@router.post("/fetch")
async def fetch_results(
request: EMOFetchRequest,
Expand Down Expand Up @@ -396,14 +390,10 @@ async def fetch_results(
# Convert objs: dict[str, list[float]] to objs: list[dict[str, float]]
raw_objs: dict[str, list[float]] = state.state.objective_values
n_solutions = len(next(iter(raw_objs.values())))
objs: list[dict[str, float]] = [
{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)
]
objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)]

raw_decs: dict[str, list[float]] = state.state.decision_variables
decs: list[dict[str, float]] = [
{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)
]
decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)]

def result_stream():
for i in range(n_solutions):
Expand All @@ -416,6 +406,7 @@ def result_stream():

return StreamingResponse(result_stream())


@router.post("/fetch_score")
async def fetch_score_bands(
request: EMOScoreRequest,
Expand All @@ -434,22 +425,22 @@ async def fetch_score_bands(
SCOREBandsResult: The results of the SCORE bands visualization.
"""
# Use context instead of manual fetch
state = context.parent_state
parent_state = context.parent_state
db_session = context.db_session
problem_db = context.problem_db

if state is None:
if parent_state is None:
raise HTTPException(status_code=404, detail="Parent state not found.")

if not isinstance(state.state, EMOIterateState):
if not isinstance(parent_state.state, EMOIterateState):
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")

if not (state.state.objective_values and state.state.decision_variables):
if not (parent_state.state.objective_values and parent_state.state.decision_variables):
raise ValueError("State does not contain results yet.")

score_config = SCOREBandsConfig() if request.config is None else request.config

raw_objs: dict[str, list[float]] = state.state.objective_values
raw_objs: dict[str, list[float]] = parent_state.state.objective_values
objs = pl.DataFrame(raw_objs)

results = score_json(
Expand All @@ -463,8 +454,8 @@ async def fetch_score_bands(
score_db_state = StateDB.create(
database_session=db_session,
problem_id=problem_db.id,
session_id=state.session_id,
parent_id=state.id,
session_id=parent_state.session_id,
parent_id=parent_state.id,
state=score_state,
)

Expand Down
17 changes: 6 additions & 11 deletions desdeo/api/routers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,27 @@
import numpy as np
import pandas as pd
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import Session, select
from sqlmodel import select

from desdeo.api.db import get_session
from desdeo.api.models import (
InteractiveSessionDB,
IntermediateSolutionRequest,
IntermediateSolutionState,
ProblemDB,
ScoreBandsRequest,
ScoreBandsResponse,
SolutionReference,
StateDB,
User,
)
from desdeo.api.models.generic import GenericIntermediateSolutionResponse
from desdeo.api.routers.user_authentication import get_current_user
from desdeo.mcdm.nimbus import solve_intermediate_solutions
from desdeo.problem import Problem
from desdeo.tools import SolverResults
from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions

from .utils import get_session_context, SessionContext
from .utils import SessionContext, get_session_context

router = APIRouter(prefix="/method/generic")


@router.post("/intermediate")
def solve_intermediate(
request: IntermediateSolutionRequest,
Expand All @@ -42,7 +39,6 @@ def solve_intermediate(
context (Annotated[SessionContext, Depends]): The session context.
"""
db_session = context.db_session
user = context.user # noqa: F841
problem_db = context.problem_db
interactive_session = context.interactive_session
parent_state = context.parent_state
Expand All @@ -64,9 +60,7 @@ def solve_intermediate(
reference_states = []

for solution_info in [request.reference_solution_1, request.reference_solution_2]:
solution_state = db_session.exec(
select(StateDB).where(StateDB.id == solution_info.state_id)
).first()
solution_state = db_session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first()

if solution_state is None:
raise HTTPException(
Expand Down Expand Up @@ -165,6 +159,7 @@ def solve_intermediate(
],
)


@router.post("/score-bands-obj-data")
def calculate_score_bands_from_objective_data(
request: ScoreBandsRequest,
Expand Down
Loading