Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions desdeo/api/models/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.types import JSON, String, TypeDecorator
from sqlmodel import Column, Field, Relationship, SQLModel

from desdeo.api.models.representative_solution import RepresentativeSolutionSetBase
from desdeo.problem.schema import (
Constant,
Constraint,
Expand All @@ -27,7 +28,6 @@
VariableDomainTypeEnum,
VariableType,
)

from desdeo.tools.utils import available_solvers

if TYPE_CHECKING:
Expand Down Expand Up @@ -240,7 +240,7 @@ class ForestProblemMetaData(SQLModel, table=True):
metadata_instance: "ProblemMetaDataDB" = Relationship(back_populates="forest_metadata")


class RepresentativeNonDominatedSolutions(SQLModel, table=True):
class RepresentativeNonDominatedSolutions(RepresentativeSolutionSetBase, SQLModel, table=True):
"""A problem metadata class to store representative solutions sets, i.e., non-dominated sets...

A problem metadata class to store representative solutions sets, i.e., non-dominated sets that
Expand All @@ -252,26 +252,12 @@ class RepresentativeNonDominatedSolutions(SQLModel, table=True):

id: int | None = Field(primary_key=True, default=None)
metadata_id: int | None = Field(foreign_key="problemmetadatadb.id", default=None)

metadata_type: str = "representative_non_dominated_solutions"

name: str = Field(description="The name of the representative set.")
description: str | None = Field(description="A description of the representative set. Optional.", default=None)

solution_data: dict[str, list[float]] = Field(
sa_column=Column(JSON),
description="The non-dominated solutions. It is assumed that columns "
"exist for each variable and objective function. For functions, the "
"`_min` variant should be present, and any tensor variables should be "
"unrolled.",
)
solution_data: dict[str, list[float]] = Field(sa_column=Column(JSON))
ideal: dict[str, float] = Field(sa_column=Column(JSON))
nadir: dict[str, float] = Field(sa_column=Column(JSON))

ideal: dict[str, float] = Field(
sa_column=Column(JSON), description="The ideal objective function values of the representative set."
)
nadir: dict[str, float] = Field(
sa_column=Column(JSON), description="The nadir objective function values of the representative set."
)

metadata_instance: "ProblemMetaDataDB" = Relationship(back_populates="representative_nd_metadata")

Expand Down
32 changes: 32 additions & 0 deletions desdeo/api/models/representative_solution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from sqlmodel import SQLModel # noqa: D100


class RepresentativeSolutionSetBase(SQLModel):
"""Shared base model for representative solution sets."""

name: str
description: str | None = None
solution_data: dict[str, list[float]]
ideal: dict[str, float]
nadir: dict[str, float]

class RepresentativeSolutionSetRequest(RepresentativeSolutionSetBase):
"""Model of the request to the representative solution set."""

problem_id: int

class RepresentativeSolutionSetInfo(SQLModel):
"""Model of the representative solution set info."""

id: int
problem_id: int
name: str
description: str | None = None
ideal: dict[str, float]
nadir: dict[str, float]

class RepresentativeSolutionSetFull(
RepresentativeSolutionSetInfo
):
"""Model of the representative solution set full info."""
solution_data: dict[str, list[float]]
12 changes: 0 additions & 12 deletions desdeo/api/models/request_models.py

This file was deleted.

88 changes: 50 additions & 38 deletions desdeo/api/routers/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
User,
UserRole,
)
from desdeo.api.models.request_models import RepresentativeSolutionSetRequest
from desdeo.api.models.representative_solution import (
RepresentativeSolutionSetFull,
RepresentativeSolutionSetInfo,
RepresentativeSolutionSetRequest,
)
from desdeo.api.routers.user_authentication import get_current_user
from desdeo.problem import Problem
from desdeo.tools.utils import available_solvers
Expand Down Expand Up @@ -320,7 +324,10 @@ def select_solver(

return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK)

@router.post("/add_representative_solution_set")
@router.post(
"/add_representative_solution_set",
response_model=RepresentativeSolutionSetInfo
)
def add_representative_solution_set(
payload: RepresentativeSolutionSetRequest,
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
Expand Down Expand Up @@ -374,15 +381,20 @@ def add_representative_solution_set(
db_session.commit()
db_session.refresh(repr_metadata)

# Attach to problem metadata
problem_metadata.representative_nd_metadata.append(repr_metadata)
db_session.add(problem_metadata)
db_session.commit()
db_session.refresh(problem_metadata)

return {"message": "Representative solution set added successfully."}
return RepresentativeSolutionSetInfo(
id=repr_metadata.id,
problem_id=problem_db.id,
name=repr_metadata.name,
description=repr_metadata.description,
ideal=repr_metadata.ideal,
nadir=repr_metadata.nadir,
)

@router.get("/all_representative_solution_sets/{problem_id}")
@router.get(
"/all_representative_solution_sets/{problem_id}",
response_model=list[RepresentativeSolutionSetInfo]
)
def get_all_representative_solution_sets(
problem_id: int,
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
Expand All @@ -405,29 +417,26 @@ def get_all_representative_solution_sets(

# Fetch metadata
problem_metadata = problem_db.problem_metadata
if not problem_metadata or not problem_metadata.representative_nd_metadata:
return {
"problem_id": problem_id,
"representative_sets": []
}
if not problem_metadata:
return []

# Build response
sets_meta = [
{
"name": rep.name,
"description": rep.description,
"ideal": rep.ideal,
"nadir": rep.nadir
}
return [
RepresentativeSolutionSetInfo(
id=rep.id,
problem_id=problem_id,
name=rep.name,
description=rep.description,
ideal=rep.ideal,
nadir=rep.nadir,
)
for rep in problem_metadata.representative_nd_metadata
]

return {
"problem_id": problem_id,
"representative_sets": sets_meta
}

@router.get("/representative_solution_set/{set_id}")
@router.get(
"/representative_solution_set/{set_id}",
response_model=RepresentativeSolutionSetFull
)
def get_representative_solution_set(
set_id: int,
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
Expand All @@ -445,16 +454,20 @@ def get_representative_solution_set(
raise HTTPException(status_code=401, detail="Unauthorized user.")

# Return all fields as a dict
return {
"id": repr_set.id,
"name": repr_set.name,
"description": repr_set.description,
"solution_data": repr_set.solution_data,
"ideal": repr_set.ideal,
"nadir": repr_set.nadir,
}

@router.delete("/representative_solution_set/{set_id}")
return RepresentativeSolutionSetFull(
id=repr_set.id,
problem_id=repr_set.metadata_instance.problem_id,
name=repr_set.name,
description=repr_set.description,
solution_data=repr_set.solution_data,
ideal=repr_set.ideal,
nadir=repr_set.nadir,
)

@router.delete(
"/representative_solution_set/{set_id}",
status_code=status.HTTP_204_NO_CONTENT
)
def delete_representative_solution_set(
set_id: int,
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
Expand All @@ -477,4 +490,3 @@ def delete_representative_solution_set(
db_session.delete(repr_metadata)
db_session.commit()

return {"detail": "Deleted successfully"}
2 changes: 1 addition & 1 deletion desdeo/api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
UserRole,
)
from desdeo.api.routers.user_authentication import get_password_hash
from desdeo.problem.testproblems import dtlz2, river_pollution_problem, dmitry_forest_problem_disc
from desdeo.problem.testproblems import dmitry_forest_problem_disc, dtlz2, river_pollution_problem


@pytest.fixture(name="session_and_user", scope="function")
Expand Down
56 changes: 33 additions & 23 deletions desdeo/api/tests/test_problem_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlmodel import select

from desdeo.api.models import ProblemDB, ProblemMetaDataDB, RepresentativeNonDominatedSolutions
from desdeo.api.models.representative_solution import RepresentativeSolutionSetRequest
from desdeo.problem.testproblems import dtlz2

from .conftest import login
Expand All @@ -20,31 +21,34 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d
session.commit()
session.refresh(problem)

# Prepare solution set JSON
solution_set_payload = {
"problem_id": problem.id,
"name": "Test solutions",
"description": "Solutions for testing",
"solution_data": {
solution_set_model = RepresentativeSolutionSetRequest(
problem_id=problem.id,
name="Test solutions",
description="Solutions for testing",
solution_data={
"x_1": [1.1, 2.2, 3.3],
"x_2": [-1.1, -2.2, -3.3],
"f_1": [0.1, 0.5, 0.9],
"f_2": [-0.1, 0.2, 199.2],
"f_1_min": [],
"f_2_min": [],
},
"ideal": {"f_1": 0.1, "f_2": -0.1},
"nadir": {"f_1": 0.9, "f_2": 199.2},
}
ideal={"f_1": 0.1, "f_2": -0.1},
nadir={"f_1": 0.9, "f_2": 199.2},
)

response = client.post(
"/problem/add_representative_solution_set",
headers={"Authorization": f"Bearer {access_token}"},
json=solution_set_payload,
json=solution_set_model.model_dump(), # send as JSON
)

assert response.status_code == 200
assert response.json()["message"] == "Representative solution set added successfully."
data = response.json()
assert data["name"] == solution_set_model.name
assert data["description"] == solution_set_model.description
assert data["ideal"] == solution_set_model.ideal
assert data["nadir"] == solution_set_model.nadir

# Verify DB
statement = select(ProblemMetaDataDB).where(ProblemMetaDataDB.problem_id == problem.id)
Expand All @@ -54,11 +58,11 @@ def test_add_representative_solution_set(client: TestClient, session_and_user: d

repr_metadata = metadata.representative_nd_metadata[0]
assert isinstance(repr_metadata, RepresentativeNonDominatedSolutions)
assert repr_metadata.name == solution_set_payload["name"]
assert repr_metadata.description == solution_set_payload["description"]
assert repr_metadata.solution_data == solution_set_payload["solution_data"]
assert repr_metadata.ideal == solution_set_payload["ideal"]
assert repr_metadata.nadir == solution_set_payload["nadir"]
assert repr_metadata.name == solution_set_model.name
assert repr_metadata.description == solution_set_model.description
assert repr_metadata.solution_data == solution_set_model.solution_data
assert repr_metadata.ideal == solution_set_model.ideal
assert repr_metadata.nadir == solution_set_model.nadir

def test_get_all_representative_solution_sets(client: TestClient, session_and_user: dict):
"""Test that all representative solution sets for a problem can be fetched (meta-level)."""
Expand Down Expand Up @@ -108,10 +112,10 @@ def test_get_all_representative_solution_sets(client: TestClient, session_and_us
assert response.status_code == 200

data = response.json()
assert data["problem_id"] == problem.id
assert len(data["representative_sets"]) == 1
assert isinstance(data, list)
assert len(data) == 1

repr_meta = data["representative_sets"][0]
repr_meta = data[0]
assert repr_meta["name"] == "Test Set GET"
assert repr_meta["description"] == "Description GET"
assert repr_meta["ideal"] == {"f_1": 0.1}
Expand Down Expand Up @@ -190,10 +194,16 @@ def test_delete_representative_solution_set(client: TestClient, session_and_user
session.commit()
session.refresh(problem)

# Create metadata properly
problem_metadata = ProblemMetaDataDB(problem_id=problem.id, problem=problem)
session.add(problem_metadata)
session.commit()
session.refresh(problem_metadata)

# Add a representative solution set
repr_metadata = RepresentativeNonDominatedSolutions(
metadata_id=ProblemMetaDataDB(problem_id=problem.id, problem=problem).id,
metadata_instance=ProblemMetaDataDB(problem_id=problem.id, problem=problem),
metadata_id=problem_metadata.id,
metadata_instance=problem_metadata,
name="To be deleted",
description="Test deletion",
solution_data={"x": [1.0], "f": [0.0]},
Expand All @@ -209,8 +219,8 @@ def test_delete_representative_solution_set(client: TestClient, session_and_user
f"/problem/representative_solution_set/{repr_metadata.id}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 200
assert response.json()["detail"] == "Deleted successfully"
assert response.status_code == 204
assert response.content == b""

# Verify DB deletion
deleted_set = session.get(RepresentativeNonDominatedSolutions, repr_metadata.id)
Expand Down
Loading