diff --git a/deployments/api/src/stitch/api/db/errors.py b/deployments/api/src/stitch/api/db/errors.py new file mode 100644 index 0000000..c5a7373 --- /dev/null +++ b/deployments/api/src/stitch/api/db/errors.py @@ -0,0 +1,10 @@ +from stitch.api.errors import StitchAPIError + + +class ResourceNotFoundError(StitchAPIError): ... + + +class ResourceIntegrityError(StitchAPIError): ... + + +class InvalidActionError(StitchAPIError): ... diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 5a256c2..596413a 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -281,6 +281,9 @@ def create_seed_sources(): GemSourceModel.from_entity( GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5) ), + GemSourceModel.from_entity( + GemData(name="Merge Target Field", country="YYZ", lat=13.37, lon=13.37) + ), ] for i, src in enumerate(gem_sources, start=1): src.id = i @@ -294,6 +297,13 @@ def create_seed_sources(): WMSourceModel.from_entity( WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0) ), + WMSourceModel.from_entity( + WMData( + field_name="Merge Consumed Field", + field_country="YYZ", + production=1337.0, + ) + ), ] for i, src in enumerate(wm_sources, start=1): src.id = i @@ -324,6 +334,8 @@ def create_seed_resources(user: UserEntity) -> list[ResourceModel]: resources = [ ResourceModel.create(user, name="Multi-Source Asset", country="USA"), ResourceModel.create(user, name="Single Source Asset", country="GBR"), + ResourceModel.create(user, name="Merge Target", country="YYZ"), + ResourceModel.create(user, name="Merge Consumed", country="YYZ"), ] for i, res in enumerate(resources, start=1): res.id = i @@ -342,6 +354,8 @@ def create_seed_memberships( MembershipModel.create(user, resources[0], "wm", wm_sources[0].id), MembershipModel.create(user, resources[0], "rmi", rmi_sources[0].id), MembershipModel.create(user, resources[1], "gem", gem_sources[1].id), + MembershipModel.create(user, resources[2], "gem", gem_sources[2].id), + MembershipModel.create(user, resources[3], "wm", wm_sources[2].id), ] for i, mem in enumerate(memberships, start=1): mem.id = i diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index 5493faa..5242f58 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -1,6 +1,7 @@ from collections import defaultdict from enum import StrEnum from sqlalchemy import ( + Enum, ForeignKey, Index, Integer, @@ -47,7 +48,9 @@ class MembershipModel(TimestampMixin, UserAuditMixin, Base): String(10), nullable=False ) # "gem" | "wm" source_pk: Mapped[int] = mapped_column(PORTABLE_BIGINT, nullable=False) - status: Mapped[MembershipStatus] + status: Mapped[MembershipStatus] = mapped_column( + Enum(MembershipStatus), nullable=False + ) @classmethod def create( diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index c0d19bc..a21a426 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py +++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import selectinload from starlette.status import HTTP_404_NOT_FOUND -from stitch.api.db.model.sources import SOURCE_TABLES, SourceModel +from .model.sources import SOURCE_TABLES, SourceModel from stitch.api.auth import CurrentUser from stitch.api.entities import ( CreateResource, @@ -19,10 +19,13 @@ SourceKey, ) +from .errors import InvalidActionError, ResourceNotFoundError, ResourceIntegrityError + from .model import ( CCReservoirsSourceModel, GemSourceModel, MembershipModel, + MembershipStatus, RMIManualSourceModel, ResourceModel, WMSourceModel, @@ -154,3 +157,107 @@ async def create_source_data(session: AsyncSession, data: CreateSourceData): rmi={rmi.id: rmi.as_entity() for rmi in rmis}, cc={cc.id: cc.as_entity() for cc in ccs}, ) + + +async def merge_resources( + session: AsyncSession, + user: CurrentUser, + resource_ids: Sequence[int], +) -> Resource: + """ + Stub "merge" behavior: + - Treat ids[0] as the canonical/target resource. + - Update all resources in ids[1:] to have repointed_id = ids[0]. + + NOTE: This only updates the resource table repointing field (no membership/source consolidation). + """ + # preserve order but drop duplicates + unique_ids = list(dict.fromkeys(resource_ids)) + if len(unique_ids) < 2: + raise InvalidActionError( + f"Merging only possible between multiple ids: received: {unique_ids}" + ) + + stmt = select(ResourceModel).where(ResourceModel.id.in_(unique_ids)) + + results = (await session.scalars(stmt)).all() + missing_ids = set(unique_ids).difference(set([r.id for r in results])) + if len(missing_ids) > 0: + msg = f"Resources not found for ids: [{','.join(map(str, missing_ids))}]" + raise ResourceNotFoundError(msg) + + if len(repointed := [r for r in results if r.repointed_id is not None]) > 0: + reprs = map(repr, repointed) + msg = f"Repointed: [{','.join(reprs)}]" + raise ResourceIntegrityError( + f"Cannot merge any resource that has already been merged. {msg}" + ) + + # all ids exist, none have already been repointed + new_resource = ResourceModel.create(created_by=user) + session.add(new_resource) + await session.flush() + + # all results are still members of the session + # changes will be picked up on commit + for res in results: + res.repointed_id = new_resource.id + + _ = await _repoint_memberships(session, user, new_resource.id, unique_ids) + + # Return the canonical resource entity + await session.refresh(new_resource, ["memberships"]) + return await resource_model_to_entity(session, new_resource) + + +async def _repoint_memberships( + session: AsyncSession, + user: CurrentUser, + to_id: int, + from_ids: Sequence[int], +): + """Create new memberships pointing to a different resource. + + Collect all memberships whose `resource_id` is in the `from_resoure_ids` argument. For each of these, create + a new membership where `resource_id` = `to_resource_id`. + + This all takes place after a `merge_resources` operation where a new ResourceModel is created. + + Args: + session: the db session + user: the logged in user + to_id: the new resource id + from_ids: the original resource_ids + + Returns: + Sequence of newly created `MembershipModel` objects. + """ + res = await session.get(ResourceModel, to_id) + if res is None: + raise ResourceNotFoundError(f"No resource found for id = {to_id}.") + + existing_memberships = ( + await session.scalars( + select(MembershipModel).where(MembershipModel.resource_id.in_(from_ids)) + ) + ).all() + + # TODO: any integrity checks? What constitutes an invalid state at this point + + # create new memberships pointing to the new resource + new_memberships: list[MembershipModel] = [] + for mem in existing_memberships: + # set status on + new_memberships.append( + MembershipModel.create( + created_by=user, + resource=res, + source=mem.source, + source_pk=mem.source_pk, + status=mem.status, + ) + ) + if mem.status == MembershipStatus.ACTIVE: + mem.status = MembershipStatus.INACTIVE + session.add_all(new_memberships) + return new_memberships diff --git a/deployments/api/src/stitch/api/errors.py b/deployments/api/src/stitch/api/errors.py new file mode 100644 index 0000000..ed32b7a --- /dev/null +++ b/deployments/api/src/stitch/api/errors.py @@ -0,0 +1 @@ +class StitchAPIError(Exception): ... diff --git a/deployments/api/src/stitch/api/routers/resources.py b/deployments/api/src/stitch/api/routers/resources.py index 8d63c15..20a7f31 100644 --- a/deployments/api/src/stitch/api/routers/resources.py +++ b/deployments/api/src/stitch/api/routers/resources.py @@ -1,12 +1,17 @@ +import logging + from collections.abc import Sequence -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException + +from pydantic import BaseModel from stitch.api.db import resource_actions from stitch.api.db.config import UnitOfWorkDep from stitch.api.auth import CurrentUser from stitch.api.entities import CreateResource, Resource +logger = logging.getLogger(__name__) router = APIRouter( prefix="/resources", @@ -33,3 +38,42 @@ async def create_resource( return await resource_actions.create( session=uow.session, user=user, resource=resource_in ) + + +class MergeRequest(BaseModel): + resource_ids: list[int] + + +@router.post("/merge", response_model=Resource) +async def merge_resources_endpoint( + *, uow: UnitOfWorkDep, user: CurrentUser, payload: MergeRequest +) -> Resource: + """ + Merge multiple resources into one (STUB): + repoint resource_ids[1:] -> resource_ids[0] + """ + ids = payload.resource_ids + # preserve order but drop duplicates + unique_ids = list(dict.fromkeys(ids)) + if len(unique_ids) < 2: + raise HTTPException( + status_code=400, detail="Provide at least 2 unique resource IDs" + ) + + logger.info( + "Merge requested by user=%s for resource_ids=%s", + getattr(user, "sub", ""), + unique_ids, + ) + + try: + return await resource_actions.merge_resources( + session=uow.session, + user=user, + resource_ids=unique_ids, + ) + except HTTPException: + raise + except Exception as exc: + logger.exception("Error while merging resources %s: %s", unique_ids, exc) + raise HTTPException(status_code=500, detail="Internal error during merge")