Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions desdeo/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
emo,
enautilus,
generic,
nautilus_navigator,
nimbus,
problem,
reference_point_method,
Expand Down Expand Up @@ -38,6 +39,7 @@
app.include_router(gnimbus_routers.router)
app.include_router(enautilus.router)
app.include_router(gdm_score_bands_routers.router)
app.include_router(nautilus_navigator.router)

origins = AuthConfig.cors_origins

Expand Down
3 changes: 3 additions & 0 deletions desdeo/api/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Database configuration file for the API."""

from sqlalchemy import event
from sqlalchemy.orm import declarative_base
from sqlmodel import Session, create_engine

from desdeo.api.config import DatabaseConfig, SettingsConfig

Base = declarative_base()

if SettingsConfig.debug:
# debug and development stuff

Expand Down
2 changes: 1 addition & 1 deletion desdeo/api/models/generic_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from desdeo.problem import Tensor, VariableType

from .nautilus import NautilusNavigatorInitializationState, NautilusNavigatorNavigationState
from .nautilus_navigator import NautilusNavigatorInitializationState, NautilusNavigatorNavigationState
from .state import (
EMOFetchState,
EMOIterateState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ class NautilusNavigatorInitializationState(SQLModel, table=True):
__tablename__ = "nautilus_navigator_initialization_states"

# Primary key referencing the base State entry.
state_id: int | None = Field(
sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True)
# state_id: int | None = Field(
# sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True)
# )
id: int | None = Field(
default=None,
primary_key=True,
foreign_key="states.id",
)

class NautilusNavigatorNavigationState(SQLModel, table=True):
Expand Down Expand Up @@ -68,12 +73,20 @@ class NautilusNavigatorNavigationState(SQLModel, table=True):

__tablename__ = "nautilus_navigator_navigation_states"

# Primary key referencing the base State entry.
state_id: int | None = Field(
sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True)
# Primary key referencing the base State entry
id: int | None = Field(
default=None,
primary_key=True,
foreign_key="states.id",
)

# Foreign key referencing base State entry
# state_id: int | None = Field(
# sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True)
# )
steps_remaining: int
reference_point: dict[str, float] = Field(sa_column=Column(JSON))
bounds: dict[str, float] | None = Field(default=None, sa_column=Column(JSON))
previous_responses: list[dict] = Field(sa_column=Column(JSON))
navigator_results: list[dict] = Field(sa_column=Column(JSON))
parent_state_id: int | None = Field(default=None)
107 changes: 70 additions & 37 deletions desdeo/api/routers/nautilus_navigator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from pydantic import ValidationError
from sqlalchemy.orm import Session

from desdeo.api.db import get_db
from desdeo.api.db import get_session
from desdeo.api.db_models import Problem as ProblemInDB
from desdeo.api.models.nautilus import (
from desdeo.api.models import StateDB
from desdeo.api.models.nautilus_navigator import (
NautilusNavigatorInitializationState,
NautilusNavigatorNavigationState,
)
Expand Down Expand Up @@ -49,36 +50,49 @@ def map_response_to_step(response: NAUTILUS_Response) -> NautilusStep:
def initialize_navigator(
request: NautilusInitRequest,
user: Annotated[User, Depends(get_current_user)],
db: Annotated[Session, Depends(get_db)],
db: Annotated[Session, Depends(get_session)],
) -> NautilusInitialResponse:
"""Initialize NAUTILUS Navigator."""
# --- Validate problem ---
problem = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first()
problem_db = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first()

if problem is None:
if problem_db is None:
raise HTTPException(status_code=404, detail="Problem not found.")
if problem.owner != user.index and problem.owner is not None:
if problem_db.owner != user.id and problem_db.owner is not None:
raise HTTPException(status_code=403, detail="Unauthorized.")

# Remove forbidden fields manually for validation
raw_value = problem_db.value.copy()
raw_value.pop("is_convex_", None)
raw_value.pop("is_linear_", None)
raw_value.pop("is_twice_differentiable_", None)

try:
problem = Problem.model_validate(problem.value)
problem = Problem.model_validate(raw_value)
# problem = Problem.model_validate(problem.value)
except ValidationError:
raise HTTPException(status_code=500, detail="Invalid problem format.")
raise HTTPException(status_code=500, detail="Invalid problem format.") # noqa: B904

# --- Run algorithm ---
response = navigator_init(problem)

# --- Create base state (assuming you have a generic State table) ---
base_state = NautilusNavigatorInitializationState()
db.add(base_state)
# --- Create state properly via StateDB ---
substate = NautilusNavigatorInitializationState()

state_row = StateDB.create(
database_session=db,
problem_id=problem_db.id,
state=substate,
session_id=user.active_session_id, # or None if not used
)
db.commit()
db.refresh(base_state)
db.refresh(state_row)

# --- Map bounds ---
reachable_bounds = response.reachable_bounds or {}

return NautilusInitialResponse(
state_id=base_state.state_id,
state_id=state_row.base_state.id,
parent_state_id=None,
navigation_point=response.navigation_point,
lower_bounds=reachable_bounds.get("lower_bounds", {}),
Expand All @@ -91,46 +105,59 @@ def initialize_navigator(
def navigate_navigator(
request: NautilusNavigateRequest,
user: Annotated[User, Depends(get_current_user)],
db: Annotated[Session, Depends(get_db)],
db: Annotated[Session, Depends(get_session)],
) -> NautilusNavigateResponse:
"""Perform NAUTILUS navigation steps."""
# --- Validate problem ---
problem = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first()
problem_db = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first()

if problem is None:
if problem_db is None:
raise HTTPException(status_code=404, detail="Problem not found.")
if problem.owner != user.index and problem.owner is not None:
if problem_db.owner != user.id and problem_db.owner is not None:
raise HTTPException(status_code=403, detail="Unauthorized.")

raw_value = problem_db.value.copy()
raw_value.pop("is_convex_", None)
raw_value.pop("is_linear_", None)
raw_value.pop("is_twice_differentiable_", None)

try:
problem = Problem.model_validate(problem.value)
problem = Problem.model_validate(raw_value)
except ValidationError:
raise HTTPException(status_code=500, detail="Invalid problem format.")

raise HTTPException(status_code=500, detail="Invalid problem format.") # noqa: B904

# --- Determine parent state ---
parent_state = (
last_nav_state = (
db.query(NautilusNavigatorNavigationState)
.order_by(NautilusNavigatorNavigationState.state_id.desc())
.order_by(NautilusNavigatorNavigationState.id.desc())
.first()
)
parent_state_id = parent_state.state_id if parent_state else None
parent_state_id = last_nav_state.id if last_nav_state else None

# --- Extract previous responses ---
previous_responses: list[NAUTILUS_Response] = []
current = parent_state
current = last_nav_state
while current:
previous_responses = [
NAUTILUS_Response.model_validate(r) for r in current.navigator_results
] + previous_responses
if current.parent_state_id:
if getattr(current, "parent_state_id", None):
current = db.query(NautilusNavigatorNavigationState).filter(
NautilusNavigatorNavigationState.state_id == current.parent_state_id
NautilusNavigatorNavigationState.id == current.parent_state_id
).first()
else:
current = None

# --- Run algorithm ---
if not previous_responses:
previous_responses = [
NAUTILUS_Response(
step_number=0,
navigation_point=request.reference_point,
reachable_solution=None,
reference_point=request.reference_point,
bounds=request.bounds,
distance_to_front=0.0,
reachable_bounds={"lower_bounds": {}, "upper_bounds": {}},
)
]

try:
new_responses = navigator_all_steps(
problem=problem,
Expand All @@ -142,10 +169,9 @@ def navigate_navigator(
except IndexError as e:
raise HTTPException(status_code=400, detail="Bounds are too restrictive.") from e
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(status_code=400, detail=str(e)) # noqa: B904

# --- Store state ---
navigation_state = NautilusNavigatorNavigationState(
substate = NautilusNavigatorNavigationState(
steps_remaining=request.steps_remaining,
reference_point=request.reference_point,
bounds=request.bounds,
Expand All @@ -154,15 +180,22 @@ def navigate_navigator(
parent_state_id=parent_state_id,
)

db.add(navigation_state)
state_row = StateDB.create(
database_session=db,
problem_id=problem_db.id,
state=substate,
session_id=user.active_session_id,
parent_id=parent_state_id,
)

db.commit()
db.refresh(navigation_state)
db.refresh(state_row)


# --- Map response ---
steps = [map_response_to_step(r) for r in new_responses]

return NautilusNavigateResponse(
state_id=navigation_state.state_id,
state_id=state_row.base_state.id,
parent_state_id=parent_state_id,
steps=steps,
)
Loading
Loading