Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion desdeo/api/models/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class NIMBUSDeleteSaveRequest(SQLModel):

state_id : int = Field(description="The ID of the save state.")
solution_index: int = Field(description="The ID of the solution within the above state.")
problem_id: int = Field(description="The ID of the problem.")


class NIMBUSFinalizeRequest(SQLModel):
Expand Down Expand Up @@ -103,7 +104,7 @@ class NIMBUSDeleteSaveResponse(SQLModel):

response_type: str = "nimbus.delete_save"

message: str | None
message: str | None = None

class NIMBUSFinalizeResponse(SQLModel):
"""The response from NIMBUS finish endpoint."""
Expand Down
165 changes: 72 additions & 93 deletions desdeo/api/routers/emo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from desdeo.problem import Problem
from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json

from .utils import fetch_interactive_session, fetch_user_problem
from .utils import fetch_interactive_session, fetch_user_problem, get_session_context, SessionContext

router = APIRouter(prefix="/method/emo", tags=["EMO"])

Expand Down Expand Up @@ -149,86 +149,64 @@ def get_templates() -> list[TemplateOptions]:
templates.append(template.template)
return templates


@router.post("/iterate")
def iterate(
request: EMOIterateRequest,
user: Annotated[User, Depends(get_current_user)],
session: Annotated[Session, Depends(get_session)],
context: Annotated[SessionContext, Depends(get_session_context)],
) -> EMOIterateResponse:
"""Starts the EMO method.

Args:
request (EMOSolveRequest): The request object containing parameters for the EMO method.
user (Annotated[User, Depends]): The current user.
session (Annotated[Session, Depends]): The database session.

Raises:
HTTPException: If the request is invalid or the EMO method fails.
"""Fetches results from a completed EMO method.

Returns:
IterateResponse: A response object containing a list of IDs to be used for websocket communication.
Also contains the StateDB id where the results will be stored.
Args: request (EMOIterateRequest): The request object containing parameters for fetching results.
context (Annotated[SessionContext, Depends]): The session context.
"""
interactive_session: InteractiveSessionDB | None = fetch_interactive_session(user, request, session)
# 1) Get context objects
db_session = context.db_session
interactive_session = context.interactive_session
parent_state = context.parent_state

# 2) Ensure problem exists
if context.problem_db is None:
raise HTTPException(status_code=404, detail="Problem not found")

problem_db = fetch_user_problem(user, request, session)
problem_db = context.problem_db
problem = Problem.from_problemdb(problem_db)

templates = request.template_options
# 3) Templates
templates = request.template_options or get_templates()

if templates is None:
templates = get_templates()
web_socket_ids = [
f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
for template in templates
]

web_socket_ids = []
for template in templates:
# Ensure unique names
web_socket_ids.append(f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}")
client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
client_id = "client"

# Save request (incomplete and EAs have not finished running yet)

# Handle parent state
if request.parent_state_id is None:
parent_state = None
else:
statement = select(StateDB).where(StateDB.id == request.parent_state_id)
parent_state = session.exec(statement).first()

if parent_state is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Could not find state with id={request.parent_state_id}",
)

# 4) Create incomplete state
emo_iterate_state = EMOIterateState(
template_options=jsonable_encoder(templates),
preference_options=jsonable_encoder(request.preference_options),
)

incomplete_db_state = StateDB.create(
database_session=session,
database_session=db_session,
problem_id=problem_db.id,
session_id=interactive_session.id if interactive_session else None,
parent_id=parent_state.id if parent_state else None,
state=emo_iterate_state,
)

session.add(incomplete_db_state)
session.commit()
session.refresh(incomplete_db_state)
db_session.add(incomplete_db_state)
db_session.commit()
db_session.refresh(incomplete_db_state)

state_id = incomplete_db_state.id
if state_id is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create a new state in the database.",
)
# Close db session
session.close()

# Spawn a new process to handle EMO method creation
# 5) Start process
Process(
target=_spawn_emo_process,
args=(
Expand Down Expand Up @@ -384,94 +362,92 @@ async def _ea_async( # noqa: PLR0913
await ws.send(f'{{"message": "Finished {websocket_id}", "send_to": "{client_id}"}}')
results_dict[websocket_id] = results


@router.post("/fetch")
async def fetch_results(
request: EMOFetchRequest,
user: Annotated[User, Depends(get_current_user)],
session: Annotated[Session, Depends(get_session)],
context: Annotated[SessionContext, Depends(get_session_context)],
) -> StreamingResponse:
"""Fetches results from a completed EMO method.

Args:
request (EMOFetchRequest): The request object containing parameters for fetching results.
user (Annotated[User, Depends]): The current user.
session (Annotated[Session, Depends]): The database session.
context (Annotated[SessionContext, Depends]): The session context.

Raises:
HTTPException: If the request is invalid or the EMO method has not completed.
Raises: HTTPException: If the request is invalid or the EMO method has not completed.

Returns:
StreamingResponse: A streaming response containing the results of the EMO method.
Returns: StreamingResponse: A streaming response containing the results of the EMO method.
"""
parent_state = request.parent_state_id
statement = select(StateDB).where(StateDB.id == parent_state)
state = session.exec(statement).first()
# Use context instead of manual fetch
state = context.parent_state

if state is None:
raise HTTPException(status_code=404, detail="Parent state not found.")

if not isinstance(state.state, EMOIterateState):
raise TypeError(f"State with id={parent_state} is not of type EMOIterateState.")
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")

if not (state.state.objective_values and state.state.decision_variables):
raise ValueError(f"State does not contain results yet.")
raise ValueError("State does not contain results yet.")

# Convert objs: dict[str, list[float]] to objs: list[dict[str, float]]
raw_objs: dict[str, list[float]] = state.state.objective_values
n_solutions = len(next(iter(raw_objs.values())))
objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)]
objs: list[dict[str, float]] = [
{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)
]

raw_decs: dict[str, list[float]] = state.state.decision_variables

decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)]

response: list[Solution] = []
decs: list[dict[str, float]] = [
{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)
]

def result_stream():
for i in range(n_solutions):
item = {"solution_id": i, "objective_values": objs[i], "decision_variables": decs[i]}
item = {
"solution_id": i,
"objective_values": objs[i],
"decision_variables": decs[i],
}
yield json.dumps(item) + "\n"

return StreamingResponse(result_stream())


@router.post("/fetch_score")
async def fetch_score_bands(
request: EMOScoreRequest,
user: Annotated[User, Depends(get_current_user)],
session: Annotated[Session, Depends(get_session)],
context: Annotated[SessionContext, Depends(get_session_context)],
) -> EMOScoreResponse:
"""Fetches results from a completed EMO method.

Args:
request (EMOFetchRequest): The request object containing parameters for fetching results and of the SCORE bands
visualization.
user (Annotated[User, Depends]): The current user.
session (Annotated[Session, Depends]): The database session.
Args: request (EMOFetchRequest): The request object containing parameters for fetching
results and of the SCORE bands visualization.
context (Annotated[SessionContext, Depends]): The session context.

Raises:
HTTPException: If the request is invalid or the EMO method has not completed.

Returns:
SCOREBandsResult: The results of the SCORE bands visualization.
"""
if request.config is None:
score_config = SCOREBandsConfig()
else:
score_config = request.config
parent_state = request.parent_state_id
statement = select(StateDB).where(StateDB.id == parent_state)
state = session.exec(statement).first()
# Use context instead of manual fetch
state = context.parent_state
db_session = context.db_session
problem_db = context.problem_db

if state is None:
raise HTTPException(status_code=404, detail="Parent state not found.")

if not isinstance(state.state, EMOIterateState):
raise TypeError(f"State with id={parent_state} is not of type EMOIterateState.")
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")

if not (state.state.objective_values and state.state.decision_variables):
raise ValueError(f"State does not contain results yet.")
raise ValueError("State does not contain results yet.")

if request.config is None:
score_config = SCOREBandsConfig()
else:
score_config = request.config

# Convert objs: dict[str, list[float]] to objs: list[dict[str, float]]
raw_objs: dict[str, list[float]] = state.state.objective_values
objs = pl.DataFrame(raw_objs)

Expand All @@ -482,16 +458,19 @@ async def fetch_score_bands(

score_state = EMOSCOREState(result=results.model_dump())

# Use the session + problem from context instead of request directly
score_db_state = StateDB.create(
database_session=session,
problem_id=request.problem_id,
session_id=request.session_id,
parent_id=parent_state,
database_session=db_session,
problem_id=problem_db.id,
session_id=state.session_id,
parent_id=state.id,
state=score_state,
)
session.add(score_db_state)
session.commit()
session.refresh(score_db_state)

db_session.add(score_db_state)
db_session.commit()
db_session.refresh(score_db_state)

state_id = score_db_state.id

return EMOScoreResponse(result=results, state_id=state_id)
Loading
Loading