From 7858301c76a96606f895b2b6679a79452af92634 Mon Sep 17 00:00:00 2001 From: Peter Bednarik Date: Thu, 26 Mar 2026 16:37:41 +0200 Subject: [PATCH] Nautilus Navigator endpoints /initialize and /navigate + tests.v. 1.0 --- desdeo/api/app.py | 2 + desdeo/api/db.py | 3 + desdeo/api/models/generic_states.py | 2 +- .../{nautilus.py => nautilus_navigator.py} | 23 ++- desdeo/api/routers/nautilus_navigator.py | 107 +++++++++----- desdeo/api/tests/test_routes.py | 132 ++++++++++++++++++ 6 files changed, 226 insertions(+), 43 deletions(-) rename desdeo/api/models/{nautilus.py => nautilus_navigator.py} (81%) diff --git a/desdeo/api/app.py b/desdeo/api/app.py index c66d135cb..d3cee285d 100644 --- a/desdeo/api/app.py +++ b/desdeo/api/app.py @@ -8,6 +8,7 @@ emo, enautilus, generic, + nautilus_navigator, nimbus, problem, reference_point_method, @@ -38,6 +39,7 @@ app.include_router(gnimbus_routers.router) app.include_router(enautilus.router) app.include_router(gdm_score_bands_routers.router) +app.include_router(nautilus_navigator.router) origins = AuthConfig.cors_origins diff --git a/desdeo/api/db.py b/desdeo/api/db.py index 0b1188993..f9922af1d 100644 --- a/desdeo/api/db.py +++ b/desdeo/api/db.py @@ -1,10 +1,13 @@ """Database configuration file for the API.""" from sqlalchemy import event +from sqlalchemy.orm import declarative_base from sqlmodel import Session, create_engine from desdeo.api.config import DatabaseConfig, SettingsConfig +Base = declarative_base() + if SettingsConfig.debug: # debug and development stuff diff --git a/desdeo/api/models/generic_states.py b/desdeo/api/models/generic_states.py index 45b38bb34..c136fcda7 100644 --- a/desdeo/api/models/generic_states.py +++ b/desdeo/api/models/generic_states.py @@ -17,7 +17,7 @@ from desdeo.problem import Tensor, VariableType -from .nautilus import NautilusNavigatorInitializationState, NautilusNavigatorNavigationState +from .nautilus_navigator import NautilusNavigatorInitializationState, NautilusNavigatorNavigationState from .state import ( EMOFetchState, EMOIterateState, diff --git a/desdeo/api/models/nautilus.py b/desdeo/api/models/nautilus_navigator.py similarity index 81% rename from desdeo/api/models/nautilus.py rename to desdeo/api/models/nautilus_navigator.py index 070b37e3b..012efcc9f 100644 --- a/desdeo/api/models/nautilus.py +++ b/desdeo/api/models/nautilus_navigator.py @@ -21,8 +21,13 @@ class NautilusNavigatorInitializationState(SQLModel, table=True): __tablename__ = "nautilus_navigator_initialization_states" # Primary key referencing the base State entry. - state_id: int | None = Field( - sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True) + # state_id: int | None = Field( + # sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True) + # ) + id: int | None = Field( + default=None, + primary_key=True, + foreign_key="states.id", ) class NautilusNavigatorNavigationState(SQLModel, table=True): @@ -68,12 +73,20 @@ class NautilusNavigatorNavigationState(SQLModel, table=True): __tablename__ = "nautilus_navigator_navigation_states" - # Primary key referencing the base State entry. - state_id: int | None = Field( - sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True) + # Primary key referencing the base State entry + id: int | None = Field( + default=None, + primary_key=True, + foreign_key="states.id", ) + + # Foreign key referencing base State entry + # state_id: int | None = Field( + # sa_column=Column(Integer, ForeignKey("states.id", ondelete="CASCADE"), primary_key=True) + # ) steps_remaining: int reference_point: dict[str, float] = Field(sa_column=Column(JSON)) bounds: dict[str, float] | None = Field(default=None, sa_column=Column(JSON)) previous_responses: list[dict] = Field(sa_column=Column(JSON)) navigator_results: list[dict] = Field(sa_column=Column(JSON)) + parent_state_id: int | None = Field(default=None) diff --git a/desdeo/api/routers/nautilus_navigator.py b/desdeo/api/routers/nautilus_navigator.py index 99661dba8..bda76ad16 100644 --- a/desdeo/api/routers/nautilus_navigator.py +++ b/desdeo/api/routers/nautilus_navigator.py @@ -6,9 +6,10 @@ from pydantic import ValidationError from sqlalchemy.orm import Session -from desdeo.api.db import get_db +from desdeo.api.db import get_session from desdeo.api.db_models import Problem as ProblemInDB -from desdeo.api.models.nautilus import ( +from desdeo.api.models import StateDB +from desdeo.api.models.nautilus_navigator import ( NautilusNavigatorInitializationState, NautilusNavigatorNavigationState, ) @@ -49,36 +50,49 @@ def map_response_to_step(response: NAUTILUS_Response) -> NautilusStep: def initialize_navigator( request: NautilusInitRequest, user: Annotated[User, Depends(get_current_user)], - db: Annotated[Session, Depends(get_db)], + db: Annotated[Session, Depends(get_session)], ) -> NautilusInitialResponse: """Initialize NAUTILUS Navigator.""" # --- Validate problem --- - problem = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first() + problem_db = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first() - if problem is None: + if problem_db is None: raise HTTPException(status_code=404, detail="Problem not found.") - if problem.owner != user.index and problem.owner is not None: + if problem_db.owner != user.id and problem_db.owner is not None: raise HTTPException(status_code=403, detail="Unauthorized.") + # Remove forbidden fields manually for validation + raw_value = problem_db.value.copy() + raw_value.pop("is_convex_", None) + raw_value.pop("is_linear_", None) + raw_value.pop("is_twice_differentiable_", None) + try: - problem = Problem.model_validate(problem.value) + problem = Problem.model_validate(raw_value) + # problem = Problem.model_validate(problem.value) except ValidationError: - raise HTTPException(status_code=500, detail="Invalid problem format.") + raise HTTPException(status_code=500, detail="Invalid problem format.") # noqa: B904 # --- Run algorithm --- response = navigator_init(problem) - # --- Create base state (assuming you have a generic State table) --- - base_state = NautilusNavigatorInitializationState() - db.add(base_state) + # --- Create state properly via StateDB --- + substate = NautilusNavigatorInitializationState() + + state_row = StateDB.create( + database_session=db, + problem_id=problem_db.id, + state=substate, + session_id=user.active_session_id, # or None if not used + ) db.commit() - db.refresh(base_state) + db.refresh(state_row) # --- Map bounds --- reachable_bounds = response.reachable_bounds or {} return NautilusInitialResponse( - state_id=base_state.state_id, + state_id=state_row.base_state.id, parent_state_id=None, navigation_point=response.navigation_point, lower_bounds=reachable_bounds.get("lower_bounds", {}), @@ -91,46 +105,59 @@ def initialize_navigator( def navigate_navigator( request: NautilusNavigateRequest, user: Annotated[User, Depends(get_current_user)], - db: Annotated[Session, Depends(get_db)], + db: Annotated[Session, Depends(get_session)], ) -> NautilusNavigateResponse: """Perform NAUTILUS navigation steps.""" - # --- Validate problem --- - problem = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first() + problem_db = db.query(ProblemInDB).filter(ProblemInDB.id == request.problem_id).first() - if problem is None: + if problem_db is None: raise HTTPException(status_code=404, detail="Problem not found.") - if problem.owner != user.index and problem.owner is not None: + if problem_db.owner != user.id and problem_db.owner is not None: raise HTTPException(status_code=403, detail="Unauthorized.") + raw_value = problem_db.value.copy() + raw_value.pop("is_convex_", None) + raw_value.pop("is_linear_", None) + raw_value.pop("is_twice_differentiable_", None) + try: - problem = Problem.model_validate(problem.value) + problem = Problem.model_validate(raw_value) except ValidationError: - raise HTTPException(status_code=500, detail="Invalid problem format.") - + raise HTTPException(status_code=500, detail="Invalid problem format.") # noqa: B904 - # --- Determine parent state --- - parent_state = ( + last_nav_state = ( db.query(NautilusNavigatorNavigationState) - .order_by(NautilusNavigatorNavigationState.state_id.desc()) + .order_by(NautilusNavigatorNavigationState.id.desc()) .first() ) - parent_state_id = parent_state.state_id if parent_state else None + parent_state_id = last_nav_state.id if last_nav_state else None - # --- Extract previous responses --- previous_responses: list[NAUTILUS_Response] = [] - current = parent_state + current = last_nav_state while current: previous_responses = [ NAUTILUS_Response.model_validate(r) for r in current.navigator_results ] + previous_responses - if current.parent_state_id: + if getattr(current, "parent_state_id", None): current = db.query(NautilusNavigatorNavigationState).filter( - NautilusNavigatorNavigationState.state_id == current.parent_state_id + NautilusNavigatorNavigationState.id == current.parent_state_id ).first() else: current = None - # --- Run algorithm --- + if not previous_responses: + previous_responses = [ + NAUTILUS_Response( + step_number=0, + navigation_point=request.reference_point, + reachable_solution=None, + reference_point=request.reference_point, + bounds=request.bounds, + distance_to_front=0.0, + reachable_bounds={"lower_bounds": {}, "upper_bounds": {}}, + ) + ] + try: new_responses = navigator_all_steps( problem=problem, @@ -142,10 +169,9 @@ def navigate_navigator( except IndexError as e: raise HTTPException(status_code=400, detail="Bounds are too restrictive.") from e except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=400, detail=str(e)) # noqa: B904 - # --- Store state --- - navigation_state = NautilusNavigatorNavigationState( + substate = NautilusNavigatorNavigationState( steps_remaining=request.steps_remaining, reference_point=request.reference_point, bounds=request.bounds, @@ -154,15 +180,22 @@ def navigate_navigator( parent_state_id=parent_state_id, ) - db.add(navigation_state) + state_row = StateDB.create( + database_session=db, + problem_id=problem_db.id, + state=substate, + session_id=user.active_session_id, + parent_id=parent_state_id, + ) + db.commit() - db.refresh(navigation_state) + db.refresh(state_row) + - # --- Map response --- steps = [map_response_to_step(r) for r in new_responses] return NautilusNavigateResponse( - state_id=navigation_state.state_id, + state_id=state_row.base_state.id, parent_state_id=parent_state_id, steps=steps, ) diff --git a/desdeo/api/tests/test_routes.py b/desdeo/api/tests/test_routes.py index a0ba1e1c6..9842f1f5f 100644 --- a/desdeo/api/tests/test_routes.py +++ b/desdeo/api/tests/test_routes.py @@ -6,6 +6,7 @@ from fastapi import status from fastapi.testclient import TestClient +from desdeo.api.db_models import Problem as ProblemInDB from desdeo.api.models import ( CreateSessionRequest, EMOFetchRequest, @@ -42,6 +43,7 @@ User, UserPublic, ) +from desdeo.api.models.generic_states import State from desdeo.api.models.nimbus import NIMBUSInitializationResponse from desdeo.api.routers.user_authentication import create_access_token from desdeo.emo.options.algorithms import rvea_options @@ -55,6 +57,136 @@ from .conftest import get_json, login, post_file_multipart, post_json from .test_models import compare_models +# --- NAUTILUS Navigator endpoint tests --- + +def test_initialize_navigator(client: TestClient, session_and_user: dict): + """Test /nautilus/initialize using the existing test user.""" + access_token = login(client) + user = session_and_user["user"] + session = session_and_user["session"] + + ProblemInDB.metadata.create_all(bind=session.bind, tables=[ProblemInDB.__table__]) + + # Create a test problem + problem = dtlz2(3, 2).model_dump() # raw dict + # Remove fields not allowed by Problem model + problem.pop("is_convex_", None) + problem.pop("is_linear_", None) + problem.pop("is_twice_differentiable_", None) + + problem_db = ProblemInDB( + owner=user.id, + name="test_problem", + kind="continuous", + obj_kind="analytical", + value=problem + ) + session.add(problem_db) + session.commit() + session.refresh(problem_db) + + response = client.post( + "/nautilus/initialize", + json={"problem_id": problem_db.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + # Assertions + assert response.status_code == 200 + data = response.json() + assert "state_id" in data + assert "navigation_point" in data + assert "step_number" in data + assert "lower_bounds" in data + assert "upper_bounds" in data + + # Clean up + session.delete(problem_db) + session.commit() + +def test_navigate_navigator(client: TestClient, session_and_user: dict): + """Test performing a NAUTILUS navigation step using the updated StateDB-based endpoint.""" + access_token = login(client) + user = session_and_user["user"] + session = session_and_user["session"] + + ProblemInDB.metadata.create_all(bind=session.bind, tables=[ProblemInDB.__table__]) + + + # --- Create a REAL problem --- + problem_obj = dtlz2(3, 2) # 3 variables, 2 objectives + problem_dict = problem_obj.model_dump() + + # Remove forbidden fields for Problem model + problem_dict.pop("is_convex_", None) + problem_dict.pop("is_linear_", None) + problem_dict.pop("is_twice_differentiable_", None) + + problem_db = ProblemInDB( + owner=user.id, + name="test_problem", + kind="continuous", + obj_kind="analytical", + value=problem_dict + ) + + session.add(problem_db) + session.commit() + session.refresh(problem_db) + + # --- Initialize first --- + init_response = client.post( + "/nautilus/initialize", + json={"problem_id": problem_db.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert init_response.status_code == 200 + init_data = init_response.json() + assert "state_id" in init_data + + # --- Prepare reference point and bounds for all objectives --- + # Use the actual objective names from the problem + validated_problem = Problem.model_validate(problem_dict) + objective_names = [obj.name for obj in validated_problem.objectives] + + ref_point = {name: 0.5 for name in objective_names} + bounds = {name: 1.0 for name in objective_names} # upper limit for each objective + + # --- Navigate --- + navigate_payload = { + "problem_id": problem_db.id, + "steps_remaining": 1, + "reference_point": ref_point, + "bounds": bounds, + "go_back_step": 0, + } + + response = client.post( + "/nautilus/navigate", + json=navigate_payload, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == 200 + + data = response.json() + assert "state_id" in data + assert "steps" in data + assert isinstance(data["steps"], list) + + if data["steps"]: + step = data["steps"][0] + assert "step_number" in step + assert "navigation_point" in step + assert "reachable_solution" in step + assert "lower_bounds" in step + assert "upper_bounds" in step + + # --- Clean up --- + session.delete(problem_db) + session.commit() + def test_user_login(client: TestClient): """Test that login works."""