Skip to content
35 changes: 16 additions & 19 deletions desdeo/api/routers/emo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from desdeo.problem import Problem
from desdeo.tools.score_bands import SCOREBandsConfig, score_json

from .utils import SessionContext, get_session_context
from .utils import ContextField, SessionContext, SessionContextGuard

router = APIRouter(prefix="/method/emo", tags=["EMO"])

Expand Down Expand Up @@ -146,7 +146,10 @@ def get_templates() -> list[TemplateOptions]:
@router.post("/iterate")
def iterate(
request: EMOIterateRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM]))
],
) -> EMOIterateResponse:
"""Fetches results from a completed EMO method.

Expand All @@ -159,16 +162,12 @@ def iterate(
"""
# Get context objects
db_session = context.db_session
interactive_session = context.interactive_session
parent_state = context.parent_state

# Ensure problem exists
if context.problem_db is None:
raise HTTPException(status_code=404, detail="Problem not found")

problem_db = context.problem_db
problem = Problem.from_problemdb(problem_db)

interactive_session = context.interactive_session
parent_state = context.parent_state

# Templates
templates = request.template_options or get_templates()

Expand Down Expand Up @@ -218,7 +217,6 @@ def iterate(

return EMOIterateResponse(method_ids=web_socket_ids, client_id=client_id, state_id=state_id)


def _spawn_emo_process(
problem: Problem,
templates: list[TemplateOptions],
Expand Down Expand Up @@ -363,7 +361,10 @@ async def _ea_async(
@router.post("/fetch")
async def fetch_results(
request: EMOFetchRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PARENT_STATE]))
],
) -> StreamingResponse:
"""Fetches results from a completed EMO method.

Expand All @@ -378,9 +379,6 @@ async def fetch_results(
# Use context instead of manual fetch
state = context.parent_state

if state is None:
raise HTTPException(status_code=404, detail="Parent state not found.")

if not isinstance(state.state, EMOIterateState):
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")

Expand All @@ -406,11 +404,13 @@ def result_stream():

return StreamingResponse(result_stream())


@router.post("/fetch_score")
async def fetch_score_bands(
request: EMOScoreRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM, ContextField.PARENT_STATE]))
],
) -> EMOScoreResponse:
"""Fetches results from a completed EMO method.

Expand All @@ -429,9 +429,6 @@ async def fetch_score_bands(
db_session = context.db_session
problem_db = context.problem_db

if parent_state is None:
raise HTTPException(status_code=404, detail="Parent state not found.")

if not isinstance(parent_state.state, EMOIterateState):
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")

Expand Down
22 changes: 10 additions & 12 deletions desdeo/api/routers/enautilus.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,22 @@
from desdeo.mcdm import ENautilusResult, enautilus_get_representative_solutions, enautilus_step
from desdeo.problem import Problem

from .utils import (
SessionContext,
get_session_context,
)
from .utils import ContextField, SessionContext, SessionContextGuard

router = APIRouter(prefix="/method/enautilus")


@router.post("/step")
def step(
request: ENautilusStepRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))],
) -> ENautilusStepResponse:
"""Steps the E-NAUTILUS method."""
# user = context.user # not used here
db_session = context.db_session

problem_db = context.problem_db
problem = Problem.from_problemdb(problem_db)

interactive_session = context.interactive_session

parent_state = context.parent_state

representative_solutions = db_session.exec(
Expand Down Expand Up @@ -247,7 +241,10 @@ def get_representative(
@router.post("/finalize")
def finalize_enautilus(
request: ENautilusFinalizeRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM, ContextField.PARENT_STATE])),
],
) -> ENautilusFinalizeResponse:
"""Finalize E-NAUTILUS by selecting the final solution.

Expand Down Expand Up @@ -365,8 +362,7 @@ def finalize_enautilus(

@router.get("/session_tree/{session_id}")
def get_session_tree(
session_id: int,
db_session: Annotated[Session, Depends(get_session)],
session_id: int, context: Annotated[SessionContext, Depends(SessionContextGuard())]
) -> ENautilusSessionTreeResponse:
"""Extract the full E-NAUTILUS decision tree for a session.

Expand All @@ -375,11 +371,13 @@ def get_session_tree(

Args:
session_id: The interactive session ID.
db_session: The database session.
context: The context of the query.

Returns:
ENautilusSessionTreeResponse with nodes, edges, root_ids, and decision_events.
"""
db_session = context.db_session

# Query step states
step_stmt = (
select(StateDB)
Expand Down
15 changes: 5 additions & 10 deletions desdeo/api/routers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
from desdeo.tools import SolverResults
from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions

from .utils import SessionContext, get_session_context
from .utils import ContextField, SessionContext, SessionContextGuard

router = APIRouter(prefix="/method/generic")


@router.post("/intermediate")
def solve_intermediate(
request: IntermediateSolutionRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM]))
],
) -> GenericIntermediateSolutionResponse:
"""Solve intermediate solutions between given two solutions.

Expand Down Expand Up @@ -82,13 +85,6 @@ def solve_intermediate(

var_and_obj_values_of_references.append((var_values, obj_values))

# Problem is now already loaded via context
if problem_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Problem with id={request.problem_id} could not be found.",
)

problem = Problem.from_problemdb(problem_db)

solver_results: list[SolverResults] = solve_intermediate_solutions(
Expand Down Expand Up @@ -145,7 +141,6 @@ def solve_intermediate(
],
)


@router.post("/score-bands-obj-data")
def calculate_score_bands_from_objective_data(
request: ScoreBandsRequest,
Expand Down
43 changes: 23 additions & 20 deletions desdeo/api/routers/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from desdeo.problem import Problem
from desdeo.tools import SolverResults

from .utils import SessionContext, get_session_context
from .utils import ContextField, SessionContext, SessionContextGuard

router = APIRouter(prefix="/method/nimbus")

Expand Down Expand Up @@ -108,7 +108,10 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list
@router.post("/solve")
def solve_solutions(
request: NIMBUSClassificationRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM]))
],
) -> NIMBUSClassificationResponse:
"""Solve the problem using the NIMBUS method."""
db_session = context.db_session
Expand All @@ -117,12 +120,6 @@ def solve_solutions(
interactive_session = context.interactive_session
parent_state = context.parent_state

# Ensure problem exists
if problem_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
)

solver = check_solver(problem_db=problem_db)
problem = Problem.from_problemdb(problem_db)

Expand Down Expand Up @@ -181,7 +178,10 @@ def solve_solutions(
@router.post("/initialize")
def initialize(
request: NIMBUSInitializationRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM]))
],
) -> NIMBUSInitializationResponse:
"""Initialize the problem for the NIMBUS method."""
db_session = context.db_session
Expand All @@ -190,11 +190,6 @@ def initialize(
interactive_session = context.interactive_session
parent_state = context.parent_state

if problem_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
)

solver = check_solver(problem_db=problem_db)
problem = Problem.from_problemdb(problem_db)

Expand Down Expand Up @@ -258,7 +253,8 @@ def initialize(

@router.post("/save")
def save(
request: NIMBUSSaveRequest, context: Annotated[SessionContext, Depends(get_session_context)]
request: NIMBUSSaveRequest,
context: Annotated[SessionContext, Depends(SessionContextGuard())]
) -> NIMBUSSaveResponse:
"""Save solutions."""
db_session = context.db_session
Expand Down Expand Up @@ -332,7 +328,10 @@ def save(
@router.post("/intermediate")
def solve_nimbus_intermediate(
request: IntermediateSolutionRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM]))
],
) -> NIMBUSIntermediateSolutionResponse:
"""Solve intermediate solutions by forwarding the request to generic intermediate endpoint with context nimbus."""
db_session = context.db_session
Expand Down Expand Up @@ -362,7 +361,10 @@ def solve_nimbus_intermediate(
@router.post("/get-or-initialize")
def get_or_initialize(
request: NIMBUSInitializationRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM]))
],
) -> (
NIMBUSInitializationResponse
| NIMBUSClassificationResponse
Expand Down Expand Up @@ -456,13 +458,14 @@ def get_or_initialize(

@router.post("/finalize")
def finalize_nimbus(
request: NIMBUSFinalizeRequest, context: Annotated[SessionContext, Depends(get_session_context)]
request: NIMBUSFinalizeRequest,
context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))]
) -> NIMBUSFinalizeResponse:
"""An endpoint for finishing up the nimbus process.

Args:
request (NIMBUSFinalizeRequest): The request containing the final solution, etc.
context (Annotated[User, get_session_context): The current context.
context (Annotated[SessionContext, SessionContextGuard): The current context.

Raises:
HTTPException
Expand Down Expand Up @@ -523,7 +526,7 @@ def finalize_nimbus(
@router.post("/delete_save")
def delete_save(
request: NIMBUSDeleteSaveRequest,
context: Annotated[SessionContext, Depends(get_session_context)],
context: Annotated[SessionContext, Depends(SessionContextGuard())]
) -> NIMBUSDeleteSaveResponse:
"""Endpoint for deleting saved solutions.

Expand Down
Loading
Loading