Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions desdeo/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"ExtraFunctionDB",
"ForestProblemMetaData",
"GenericIntermediateSolutionResponse",
"GetSessionRequest",
"GNIMBUSOptimizationState",
"GNIMBUSVotingState",
"GNIMBUSEndState",
Expand Down Expand Up @@ -56,7 +55,6 @@
"PreferredRanges",
"ProblemDB",
"ProblemAddFromJSONRequest",
"ProblemGetRequest",
"ProblemInfo",
"ProblemInfoSmall",
"ProblemMetaDataDB",
Expand Down Expand Up @@ -237,7 +235,6 @@
ObjectiveDB,
ProblemAddFromJSONRequest,
ProblemDB,
ProblemGetRequest,
ProblemInfo,
ProblemInfoSmall,
ProblemMetaDataDB,
Expand All @@ -254,7 +251,6 @@
from .reference_point_method import RPMSolveRequest
from .session import (
CreateSessionRequest,
GetSessionRequest,
InteractiveSessionBase,
InteractiveSessionDB,
InteractiveSessionInfo,
Expand Down
6 changes: 0 additions & 6 deletions desdeo/api/models/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ class ProblemBase(SQLModel):
variable_domain: VariableDomainTypeEnum | None = Field()


class ProblemGetRequest(SQLModel):
"""Model to deal with problem fetching requests."""

problem_id: int


class ProblemSelectSolverRequest(SQLModel):
"""Model to request a specific solver for a problem."""

Expand Down
9 changes: 1 addition & 8 deletions desdeo/api/models/representative_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ class RepresentativeSolutionSetBase(SQLModel):
ideal: dict[str, float]
nadir: dict[str, float]

class RepresentativeSolutionSetRequest(RepresentativeSolutionSetBase):
"""Model of the request to the representative solution set."""

problem_id: int

class RepresentativeSolutionSetInfo(SQLModel):
"""Model of the representative solution set info."""

Expand All @@ -25,8 +20,6 @@ class RepresentativeSolutionSetInfo(SQLModel):
ideal: dict[str, float]
nadir: dict[str, float]

class RepresentativeSolutionSetFull(
RepresentativeSolutionSetInfo
):
class RepresentativeSolutionSetFull(RepresentativeSolutionSetInfo):
"""Model of the representative solution set full info."""
solution_data: dict[str, list[float]]
6 changes: 0 additions & 6 deletions desdeo/api/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ class CreateSessionRequest(SQLModel):
info: str | None = Field(default=None)


class GetSessionRequest(SQLModel):
"""Model of the request to get a specific session."""

session_id: int = Field()


class InteractiveSessionBase(SQLModel):
"""The base model for representing interactive sessions."""

Expand Down
35 changes: 18 additions & 17 deletions desdeo/api/routers/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from desdeo.api.models import (
ForestProblemMetaData,
ProblemDB,
ProblemGetRequest,
ProblemInfo,
ProblemInfoSmall,
ProblemMetaDataDB,
Expand All @@ -22,9 +21,9 @@
UserRole,
)
from desdeo.api.models.representative_solution import (
RepresentativeSolutionSetBase,
RepresentativeSolutionSetFull,
RepresentativeSolutionSetInfo,
RepresentativeSolutionSetRequest,
)
from desdeo.api.routers.user_authentication import get_current_user
from desdeo.problem import Problem
Expand Down Expand Up @@ -89,27 +88,28 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[
"""
return user.problems


@router.post("/get")
@router.get("/{problem_id}")
def get_problem(
request: ProblemGetRequest,
context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))],
problem_id: int,
context: Annotated[
SessionContext,
Depends(SessionContextGuard(require=[ContextField.PROBLEM])),
],
) -> ProblemInfo:
"""Get the model of a specific problem.
"""Get a specific problem by id.

Args:
request (ProblemGetRequest): the request containing the problem's id `problem_id`.
context (Annotated[SessionContext, Depends): the session context.
problem_id (int): problem id.
context (Annotated[SessionContext, Depends): the session context.

Raises:
HTTPException: could not find a problem with the given id.
HTTPException: could not find a problem with the given id.

Returns:
ProblemInfo: detailed information on the requested problem.
ProblemInfo: detailed information on the requested problem.
"""
return context.problem_db


@router.post("/add")
def add_problem(
request: Annotated[Problem, Depends(parse_problem_json)],
Expand Down Expand Up @@ -299,15 +299,17 @@ def select_solver(
return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK)


@router.post("/add_representative_solution_set")
@router.post("/{problem_id}/add_representative_solution_set")
def add_representative_solution_set(
request: RepresentativeSolutionSetRequest,
problem_id: int,
request: RepresentativeSolutionSetBase,
context: Annotated[SessionContext, Depends(SessionContextGuard(require=[ContextField.PROBLEM]))],
):
"""Add a new representative solution set as metadata to a problem.

Args:
request (RepresentativeSolutionSetRequest): The JSON body containing the
problem_id: int,
request (RepresentativeSolutionSetBase): The JSON body containing the
details of the representative solution set (name, description, solution data, ideal, nadir).
context (SessionContext): The session context providing the current user and database session.

Expand Down Expand Up @@ -349,8 +351,7 @@ def add_representative_solution_set(
nadir=repr_metadata.nadir,
)


@router.get("/all_representative_solution_sets/{problem_id}")
@router.get("/{problem_id}/all_representative_solution_sets")
def get_all_representative_solution_sets(
problem_id: int,
context: Annotated[SessionContext, Depends(SessionContextGuard(require=[]))],
Expand Down
8 changes: 2 additions & 6 deletions desdeo/api/routers/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from desdeo.api.db import get_session as get_db_session
from desdeo.api.models import (
CreateSessionRequest,
GetSessionRequest,
InteractiveSessionDB,
InteractiveSessionInfo,
User,
Expand Down Expand Up @@ -51,10 +50,9 @@ def get_session(
session: Annotated[Session, Depends(get_db_session)],
) -> InteractiveSessionInfo:
"""Return an interactive session with a current user."""
request = GetSessionRequest(session_id=session_id)
return fetch_interactive_session(
user=user,
request=request,
session_id=session_id,
session=session,
)

Expand Down Expand Up @@ -84,11 +82,9 @@ def delete_session(
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete an interactive session and all its related states."""
request = GetSessionRequest(session_id=session_id)

interactive_session = fetch_interactive_session(
user=user,
request=request,
session_id=session_id,
session=session,
) # raises 404 if not found

Expand Down
47 changes: 39 additions & 8 deletions desdeo/api/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,26 @@
StateDB,
User,
)
from desdeo.api.models.representative_solution import RepresentativeSolutionSetRequest
from desdeo.api.models.session import CreateSessionRequest
from desdeo.api.routers.user_authentication import get_current_user

RequestType = RPMSolveRequest | ENautilusStepRequest | RepresentativeSolutionSetRequest | CreateSessionRequest
RequestType = RPMSolveRequest | ENautilusStepRequest | CreateSessionRequest


def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None:
def fetch_interactive_session(
user: User,
session: Session,
request: RequestType | None = None,
session_id: int | None = None,
) -> InteractiveSessionDB | None:
"""Gets the desired instance of `InteractiveSessionDB`.

Args:
user (User): the user whose interactive sessions are to be queried.
request (RequestType): the request with possibly information on which interactive session to query.
session (Session): the database session (not to be confused with the interactive session) from
which the interactive session should be queried.
session_id (int): the id of a session

Note:
If no explicit `session_id` is given in `request`, this function will try to fetch the
Expand All @@ -47,23 +52,28 @@ def fetch_interactive_session(user: User, request: RequestType, session: Session
Returns:
InteractiveSessionDB | None: an interactive session DB model, or nothing.
"""
if request.session_id is not None:
# session_id param has highest priority
actual_session_id = session_id or (getattr(request, "session_id", None) if request else None)


if actual_session_id is not None:
# specific interactive session id is given, try using that
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == actual_session_id)
interactive_session = session.exec(statement).first()

if interactive_session is None:
# Raise if explicitly requested interactive session cannot be found
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Could not find interactive session with id={request.session_id}.",
detail=f"Could not find interactive session with id={actual_session_id}.",
)
else:
# request.session_id is None
if user.active_session_id is None:
return None
# actual_session_id is None
# try to use active session instead

statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)

interactive_session = session.exec(statement).first()

# At this point interactive_session is either an instance of InteractiveSessionDB or None (which is fine)
Expand Down Expand Up @@ -184,6 +194,7 @@ def __call__(
user: Annotated[User, Depends(get_current_user)],
db_session: Annotated[Session, Depends(get_session)],
request: RequestType | None = None,
problem_id: int | None = None,
) -> SessionContext:
"""Call method for the SessionContextGuard class.

Expand All @@ -192,6 +203,7 @@ def __call__(
db_session (Annotated[Session, Depends): the current database session (dep).
request (RequestType | None, optional): request based on which the context is fetched.
Defaults to None.
problem_id (int): ID of the problem.

Returns:
SessionContext: the session context with the required fields specified in `self.require`.
Expand All @@ -205,6 +217,14 @@ def __call__(
if hasattr(request, "problem_id"):
problem_db = fetch_user_problem(user, request, db_session)

if problem_db is None and problem_id is not None:
class _ProblemOnly:
def __init__(self, problem_id: int):
self.problem_id = problem_id
self.session_id = None
self.parent_state_id = None
problem_db = fetch_user_problem(user, _ProblemOnly(problem_id), db_session)

if hasattr(request, "interactive_session_id") or hasattr(request, "problem_id"):
interactive_session = fetch_interactive_session(user, request, db_session)

Expand All @@ -215,6 +235,17 @@ def __call__(
db_session,
interactive_session=interactive_session,
)
elif problem_id is not None:
# Build a minimal fake request-like object
class _ProblemOnly:
def __init__(self, problem_id: int):
self.problem_id = problem_id
self.session_id = None
self.parent_state_id = None

pseudo_request = _ProblemOnly(problem_id)

problem_db = fetch_user_problem(user, pseudo_request, db_session)

context = SessionContext(
user=user,
Expand Down
Loading
Loading