diff --git a/desdeo/api/models/problem.py b/desdeo/api/models/problem.py index 53ba08a31..ff7a382ef 100644 --- a/desdeo/api/models/problem.py +++ b/desdeo/api/models/problem.py @@ -10,6 +10,7 @@ from sqlalchemy.types import JSON, String, TypeDecorator from sqlmodel import Column, Field, Relationship, SQLModel +from desdeo.api.models.representative_solution import RepresentativeSolutionSetBase from desdeo.problem.schema import ( Constant, Constraint, @@ -27,7 +28,6 @@ VariableDomainTypeEnum, VariableType, ) - from desdeo.tools.utils import available_solvers if TYPE_CHECKING: @@ -240,7 +240,7 @@ class ForestProblemMetaData(SQLModel, table=True): metadata_instance: "ProblemMetaDataDB" = Relationship(back_populates="forest_metadata") -class RepresentativeNonDominatedSolutions(SQLModel, table=True): +class RepresentativeNonDominatedSolutions(RepresentativeSolutionSetBase, SQLModel, table=True): """A problem metadata class to store representative solutions sets, i.e., non-dominated sets... A problem metadata class to store representative solutions sets, i.e., non-dominated sets that @@ -252,26 +252,12 @@ class RepresentativeNonDominatedSolutions(SQLModel, table=True): id: int | None = Field(primary_key=True, default=None) metadata_id: int | None = Field(foreign_key="problemmetadatadb.id", default=None) - metadata_type: str = "representative_non_dominated_solutions" - name: str = Field(description="The name of the representative set.") - description: str | None = Field(description="A description of the representative set. Optional.", default=None) - - solution_data: dict[str, list[float]] = Field( - sa_column=Column(JSON), - description="The non-dominated solutions. It is assumed that columns " - "exist for each variable and objective function. For functions, the " - "`_min` variant should be present, and any tensor variables should be " - "unrolled.", - ) + solution_data: dict[str, list[float]] = Field(sa_column=Column(JSON)) + ideal: dict[str, float] = Field(sa_column=Column(JSON)) + nadir: dict[str, float] = Field(sa_column=Column(JSON)) - ideal: dict[str, float] = Field( - sa_column=Column(JSON), description="The ideal objective function values of the representative set." - ) - nadir: dict[str, float] = Field( - sa_column=Column(JSON), description="The nadir objective function values of the representative set." - ) metadata_instance: "ProblemMetaDataDB" = Relationship(back_populates="representative_nd_metadata") diff --git a/desdeo/api/models/representative_solution.py b/desdeo/api/models/representative_solution.py new file mode 100644 index 000000000..8b643768c --- /dev/null +++ b/desdeo/api/models/representative_solution.py @@ -0,0 +1,32 @@ +from sqlmodel import SQLModel # noqa: D100 + + +class RepresentativeSolutionSetBase(SQLModel): + """Shared base model for representative solution sets.""" + + name: str + description: str | None = None + solution_data: dict[str, list[float]] + ideal: dict[str, float] + nadir: dict[str, float] + +class RepresentativeSolutionSetRequest(RepresentativeSolutionSetBase): + """Model of the request to the representative solution set.""" + + problem_id: int + +class RepresentativeSolutionSetInfo(SQLModel): + """Model of the representative solution set info.""" + + id: int + problem_id: int + name: str + description: str | None = None + ideal: dict[str, float] + nadir: dict[str, float] + +class RepresentativeSolutionSetFull( + RepresentativeSolutionSetInfo +): + """Model of the representative solution set full info.""" + solution_data: dict[str, list[float]] diff --git a/desdeo/api/models/request_models.py b/desdeo/api/models/request_models.py deleted file mode 100644 index 84ee74da1..000000000 --- a/desdeo/api/models/request_models.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Requests models.""" -from pydantic import BaseModel - - -class RepresentativeSolutionSetRequest(BaseModel): - """Model of the request to the representative solution set.""" - problem_id: int - name: str - description: str | None = None - solution_data: dict[str, list[float]] - ideal: dict[str, float] - nadir: dict[str, float] diff --git a/desdeo/api/routers/problem.py b/desdeo/api/routers/problem.py index 438d575eb..9d85960f2 100644 --- a/desdeo/api/routers/problem.py +++ b/desdeo/api/routers/problem.py @@ -21,7 +21,11 @@ User, UserRole, ) -from desdeo.api.models.request_models import RepresentativeSolutionSetRequest +from desdeo.api.models.representative_solution import ( + RepresentativeSolutionSetFull, + RepresentativeSolutionSetInfo, + RepresentativeSolutionSetRequest, +) from desdeo.api.routers.user_authentication import get_current_user from desdeo.problem import Problem from desdeo.tools.utils import available_solvers @@ -320,7 +324,10 @@ def select_solver( return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK) -@router.post("/add_representative_solution_set") +@router.post( + "/add_representative_solution_set", + response_model=RepresentativeSolutionSetInfo +) def add_representative_solution_set( payload: RepresentativeSolutionSetRequest, context: Annotated[SessionContext, Depends(get_session_context_without_request)], @@ -374,15 +381,20 @@ def add_representative_solution_set( db_session.commit() db_session.refresh(repr_metadata) - # Attach to problem metadata - problem_metadata.representative_nd_metadata.append(repr_metadata) - db_session.add(problem_metadata) - db_session.commit() - db_session.refresh(problem_metadata) - return {"message": "Representative solution set added successfully."} + return RepresentativeSolutionSetInfo( + id=repr_metadata.id, + problem_id=problem_db.id, + name=repr_metadata.name, + description=repr_metadata.description, + ideal=repr_metadata.ideal, + nadir=repr_metadata.nadir, + ) -@router.get("/all_representative_solution_sets/{problem_id}") +@router.get( + "/all_representative_solution_sets/{problem_id}", + response_model=list[RepresentativeSolutionSetInfo] +) def get_all_representative_solution_sets( problem_id: int, context: Annotated[SessionContext, Depends(get_session_context_without_request)], @@ -405,29 +417,26 @@ def get_all_representative_solution_sets( # Fetch metadata problem_metadata = problem_db.problem_metadata - if not problem_metadata or not problem_metadata.representative_nd_metadata: - return { - "problem_id": problem_id, - "representative_sets": [] - } + if not problem_metadata: + return [] # Build response - sets_meta = [ - { - "name": rep.name, - "description": rep.description, - "ideal": rep.ideal, - "nadir": rep.nadir - } + return [ + RepresentativeSolutionSetInfo( + id=rep.id, + problem_id=problem_id, + name=rep.name, + description=rep.description, + ideal=rep.ideal, + nadir=rep.nadir, + ) for rep in problem_metadata.representative_nd_metadata ] - return { - "problem_id": problem_id, - "representative_sets": sets_meta - } - -@router.get("/representative_solution_set/{set_id}") +@router.get( + "/representative_solution_set/{set_id}", + response_model=RepresentativeSolutionSetFull +) def get_representative_solution_set( set_id: int, context: Annotated[SessionContext, Depends(get_session_context_without_request)], @@ -445,16 +454,20 @@ def get_representative_solution_set( raise HTTPException(status_code=401, detail="Unauthorized user.") # Return all fields as a dict - return { - "id": repr_set.id, - "name": repr_set.name, - "description": repr_set.description, - "solution_data": repr_set.solution_data, - "ideal": repr_set.ideal, - "nadir": repr_set.nadir, - } - -@router.delete("/representative_solution_set/{set_id}") + return RepresentativeSolutionSetFull( + id=repr_set.id, + problem_id=repr_set.metadata_instance.problem_id, + name=repr_set.name, + description=repr_set.description, + solution_data=repr_set.solution_data, + ideal=repr_set.ideal, + nadir=repr_set.nadir, + ) + +@router.delete( + "/representative_solution_set/{set_id}", + status_code=status.HTTP_204_NO_CONTENT +) def delete_representative_solution_set( set_id: int, context: Annotated[SessionContext, Depends(get_session_context_without_request)], @@ -477,4 +490,3 @@ def delete_representative_solution_set( db_session.delete(repr_metadata) db_session.commit() - return {"detail": "Deleted successfully"} diff --git a/desdeo/api/tests/conftest.py b/desdeo/api/tests/conftest.py index 900536cea..1388b707b 100644 --- a/desdeo/api/tests/conftest.py +++ b/desdeo/api/tests/conftest.py @@ -20,7 +20,7 @@ UserRole, ) from desdeo.api.routers.user_authentication import get_password_hash -from desdeo.problem.testproblems import dtlz2, river_pollution_problem, dmitry_forest_problem_disc +from desdeo.problem.testproblems import dmitry_forest_problem_disc, dtlz2, river_pollution_problem @pytest.fixture(name="session_and_user", scope="function") diff --git a/desdeo/api/tests/test_problem_metadata.py b/desdeo/api/tests/test_problem_metadata.py index d11b4ec79..43a9b0748 100644 --- a/desdeo/api/tests/test_problem_metadata.py +++ b/desdeo/api/tests/test_problem_metadata.py @@ -2,6 +2,7 @@ from sqlmodel import select from desdeo.api.models import ProblemDB, ProblemMetaDataDB, RepresentativeNonDominatedSolutions +from desdeo.api.models.representative_solution import RepresentativeSolutionSetRequest from desdeo.problem.testproblems import dtlz2 from .conftest import login @@ -20,12 +21,11 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d session.commit() session.refresh(problem) - # Prepare solution set JSON - solution_set_payload = { - "problem_id": problem.id, - "name": "Test solutions", - "description": "Solutions for testing", - "solution_data": { + solution_set_model = RepresentativeSolutionSetRequest( + problem_id=problem.id, + name="Test solutions", + description="Solutions for testing", + solution_data={ "x_1": [1.1, 2.2, 3.3], "x_2": [-1.1, -2.2, -3.3], "f_1": [0.1, 0.5, 0.9], @@ -33,18 +33,22 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d "f_1_min": [], "f_2_min": [], }, - "ideal": {"f_1": 0.1, "f_2": -0.1}, - "nadir": {"f_1": 0.9, "f_2": 199.2}, - } + ideal={"f_1": 0.1, "f_2": -0.1}, + nadir={"f_1": 0.9, "f_2": 199.2}, + ) response = client.post( "/problem/add_representative_solution_set", headers={"Authorization": f"Bearer {access_token}"}, - json=solution_set_payload, + json=solution_set_model.model_dump(), # send as JSON ) assert response.status_code == 200 - assert response.json()["message"] == "Representative solution set added successfully." + data = response.json() + assert data["name"] == solution_set_model.name + assert data["description"] == solution_set_model.description + assert data["ideal"] == solution_set_model.ideal + assert data["nadir"] == solution_set_model.nadir # Verify DB statement = select(ProblemMetaDataDB).where(ProblemMetaDataDB.problem_id == problem.id) @@ -54,11 +58,11 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d repr_metadata = metadata.representative_nd_metadata[0] assert isinstance(repr_metadata, RepresentativeNonDominatedSolutions) - assert repr_metadata.name == solution_set_payload["name"] - assert repr_metadata.description == solution_set_payload["description"] - assert repr_metadata.solution_data == solution_set_payload["solution_data"] - assert repr_metadata.ideal == solution_set_payload["ideal"] - assert repr_metadata.nadir == solution_set_payload["nadir"] + assert repr_metadata.name == solution_set_model.name + assert repr_metadata.description == solution_set_model.description + assert repr_metadata.solution_data == solution_set_model.solution_data + assert repr_metadata.ideal == solution_set_model.ideal + assert repr_metadata.nadir == solution_set_model.nadir def test_get_all_representative_solution_sets(client: TestClient, session_and_user: dict): """Test that all representative solution sets for a problem can be fetched (meta-level).""" @@ -108,10 +112,10 @@ def test_get_all_representative_solution_sets(client: TestClient, session_and_us assert response.status_code == 200 data = response.json() - assert data["problem_id"] == problem.id - assert len(data["representative_sets"]) == 1 + assert isinstance(data, list) + assert len(data) == 1 - repr_meta = data["representative_sets"][0] + repr_meta = data[0] assert repr_meta["name"] == "Test Set GET" assert repr_meta["description"] == "Description GET" assert repr_meta["ideal"] == {"f_1": 0.1} @@ -190,10 +194,16 @@ def test_delete_representative_solution_set(client: TestClient, session_and_user session.commit() session.refresh(problem) + # Create metadata properly + problem_metadata = ProblemMetaDataDB(problem_id=problem.id, problem=problem) + session.add(problem_metadata) + session.commit() + session.refresh(problem_metadata) + # Add a representative solution set repr_metadata = RepresentativeNonDominatedSolutions( - metadata_id=ProblemMetaDataDB(problem_id=problem.id, problem=problem).id, - metadata_instance=ProblemMetaDataDB(problem_id=problem.id, problem=problem), + metadata_id=problem_metadata.id, + metadata_instance=problem_metadata, name="To be deleted", description="Test deletion", solution_data={"x": [1.0], "f": [0.0]}, @@ -209,8 +219,8 @@ def test_delete_representative_solution_set(client: TestClient, session_and_user f"/problem/representative_solution_set/{repr_metadata.id}", headers={"Authorization": f"Bearer {access_token}"}, ) - assert response.status_code == 200 - assert response.json()["detail"] == "Deleted successfully" + assert response.status_code == 204 + assert response.content == b"" # Verify DB deletion deleted_set = session.get(RepresentativeNonDominatedSolutions, repr_metadata.id)