diff --git a/deployments/api/pyproject.toml b/deployments/api/pyproject.toml index 3aed224..869825b 100644 --- a/deployments/api/pyproject.toml +++ b/deployments/api/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "pydantic-settings>=2.12.0", "sqlalchemy>=2.0.44", "stitch-auth", + "stitch-models", ] [project.scripts] @@ -41,3 +42,4 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } +stitch-models = { workspace = true } diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 5a256c2..06b88f5 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -12,20 +12,12 @@ from sqlalchemy.orm import Session from stitch.api.db.model import ( - CCReservoirsSourceModel, - GemSourceModel, - MembershipModel, - RMIManualSourceModel, ResourceModel, StitchBase, UserModel, - WMSourceModel, ) from stitch.api.entities import ( - GemData, - RMIManualData, User as UserEntity, - WMData, ) """ @@ -273,81 +265,16 @@ def create_dev_user() -> UserModel: ) -def create_seed_sources(): - gem_sources = [ - GemSourceModel.from_entity( - GemData(name="Permian Basin Field", country="USA", lat=31.8, lon=-102.3) - ), - GemSourceModel.from_entity( - GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5) - ), - ] - for i, src in enumerate(gem_sources, start=1): - src.id = i - - wm_sources = [ - WMSourceModel.from_entity( - WMData( - field_name="Eagle Ford Shale", field_country="USA", production=125000.5 - ) - ), - WMSourceModel.from_entity( - WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0) - ), - ] - for i, src in enumerate(wm_sources, start=1): - src.id = i - - rmi_sources = [ - RMIManualSourceModel.from_entity( - RMIManualData( - name_override="Custom Override Name", - gwp=25.5, - gor=0.45, - country="CAN", - latitude=56.7, - longitude=-111.4, - ) - ), - ] - for i, src in enumerate(rmi_sources, start=1): - src.id = i - - # CC Reservoir sources are intentionally omitted from the dev seed profile; - # the CCReservoirsSourceModel table is still created from SQLAlchemy metadata. - cc_sources: list[CCReservoirsSourceModel] = [] - - return gem_sources, wm_sources, rmi_sources, cc_sources - - def create_seed_resources(user: UserEntity) -> list[ResourceModel]: resources = [ - ResourceModel.create(user, name="Multi-Source Asset", country="USA"), - ResourceModel.create(user, name="Single Source Asset", country="GBR"), + ResourceModel.create(user, name="Resource Foo01"), + ResourceModel.create(user, name="Resource Bar01"), ] for i, res in enumerate(resources, start=1): res.id = i return resources -def create_seed_memberships( - user: UserEntity, - resources: list[ResourceModel], - gem_sources: list[GemSourceModel], - wm_sources: list[WMSourceModel], - rmi_sources: list[RMIManualSourceModel], -) -> list[MembershipModel]: - memberships = [ - MembershipModel.create(user, resources[0], "gem", gem_sources[0].id), - MembershipModel.create(user, resources[0], "wm", wm_sources[0].id), - MembershipModel.create(user, resources[0], "rmi", rmi_sources[0].id), - MembershipModel.create(user, resources[1], "gem", gem_sources[1].id), - ] - for i, mem in enumerate(memberships, start=1): - mem.id = i - return memberships - - def reset_sequences(engine, tables: Iterable[str]) -> None: with engine.begin() as conn: for t in tables: @@ -376,34 +303,15 @@ def seed_dev(engine) -> None: name=user_model.name, ) - dev_entity = UserEntity( - id=dev_model.id, - sub=dev_model.sub, - email=dev_model.email, - name=dev_model.name, - ) - - gem_sources, wm_sources, rmi_sources, cc_sources = create_seed_sources() - session.add_all(gem_sources + wm_sources + rmi_sources + cc_sources) - resources = create_seed_resources(user_entity) - resources = create_seed_resources(dev_entity) session.add_all(resources) - memberships = create_seed_memberships( - user_entity, resources, gem_sources, wm_sources, rmi_sources - ) - session.add_all(memberships) - session.commit() reset_sequences( engine, tables=[ "users", - "gem_sources", - "wm_sources", - "rmi_manual_sources", "resources", "memberships", ], diff --git a/deployments/api/src/stitch/api/db/model/__init__.py b/deployments/api/src/stitch/api/db/model/__init__.py index d3f74ab..9a20770 100644 --- a/deployments/api/src/stitch/api/db/model/__init__.py +++ b/deployments/api/src/stitch/api/db/model/__init__.py @@ -1,21 +1,11 @@ from .common import Base as StitchBase -from .sources import ( - GemSourceModel, - RMIManualSourceModel, - CCReservoirsSourceModel, - WMSourceModel, -) from .resource import MembershipStatus, MembershipModel, ResourceModel from .user import User as UserModel __all__ = [ - "CCReservoirsSourceModel", - "GemSourceModel", "MembershipModel", "MembershipStatus", - "RMIManualSourceModel", "ResourceModel", "StitchBase", "UserModel", - "WMSourceModel", ] diff --git a/deployments/api/src/stitch/api/db/model/mixins.py b/deployments/api/src/stitch/api/db/model/mixins.py index b579582..e689933 100644 --- a/deployments/api/src/stitch/api/db/model/mixins.py +++ b/deployments/api/src/stitch/api/db/model/mixins.py @@ -1,11 +1,10 @@ from datetime import datetime from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin -from pydantic import TypeAdapter +from pydantic import BaseModel, TypeAdapter from sqlalchemy import DateTime, ForeignKey, String, func from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column -from stitch.api.entities import SourceBase from .types import StitchJson @@ -30,7 +29,7 @@ class UserAuditMixin: last_updated_by_id: Mapped[int] = mapped_column(ForeignKey("users.id")) -TPayload = TypeVar("TPayload", bound=SourceBase) +TPayload = TypeVar("TPayload", bound=BaseModel) def _extract_payload_type(cls: type) -> type | None: @@ -63,7 +62,9 @@ def payload(self) -> TPayload: @payload.inplace.setter def _payload_setter(self, value: TPayload): - self.source = value.source + # Domain-agnostic: if the payload has a `source` attribute, keep it; + # otherwise set a neutral default. + self.source = getattr(value, "source", "unknown") self._payload_data = value.model_dump(mode="json") @payload.inplace.expression diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index 5493faa..48f0a97 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -1,4 +1,3 @@ -from collections import defaultdict from enum import StrEnum from sqlalchemy import ( ForeignKey, @@ -12,13 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship -from .sources import ( - SOURCE_TABLES, - SourceKey, - SourceModel, - SourceModelData, -) -from stitch.api.entities import IdType, User as UserEntity +from stitch.models.types import IdType +from stitch.api.entities import User as UserEntity from .common import Base from .mixins import TimestampMixin, UserAuditMixin from .types import PORTABLE_BIGINT @@ -43,9 +37,7 @@ class MembershipModel(TimestampMixin, UserAuditMixin, Base): ) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) resource_id: Mapped[int] = mapped_column(ForeignKey("resources.id"), nullable=False) - source: Mapped[SourceKey] = mapped_column( - String(10), nullable=False - ) # "gem" | "wm" + source: Mapped[str] = mapped_column(String(10), nullable=False) # "gem" | "wm" source_pk: Mapped[int] = mapped_column(PORTABLE_BIGINT, nullable=False) status: Mapped[MembershipStatus] @@ -54,14 +46,14 @@ def create( cls, created_by: UserEntity, resource: "ResourceModel", - source: SourceKey, + source: str, source_pk: IdType, status: MembershipStatus = MembershipStatus.ACTIVE, ): model = cls( resource_id=resource.id, source=source, - source_pk=str(source_pk), + source_pk=int(source_pk), status=status, created_by_id=created_by.id, last_updated_by_id=created_by.id, @@ -98,23 +90,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base): # and configure the appropriate SQL statement to load the membership objects memberships: Mapped[list[MembershipModel]] = relationship() - async def get_source_data(self, session: AsyncSession): - pks_by_src: dict[SourceKey, set[int]] = defaultdict(set) - for mem in self.memberships: - if mem.status == MembershipStatus.ACTIVE: - pks_by_src[mem.source].add(mem.source_pk) - - results: dict[SourceKey, dict[IdType, SourceModel]] = defaultdict(dict) - for src, pks in pks_by_src.items(): - model_cls = SOURCE_TABLES.get(src) - if model_cls is None: - continue - stmt = select(model_cls).where(model_cls.id.in_(pks)) - for src_model in await session.scalars(stmt): - results[src][src_model.id] = src_model - - return SourceModelData(**results) - async def get_root(self, session: AsyncSession): root = await session.scalar(self.__class__._root_select(self.id)) if root is None: diff --git a/deployments/api/src/stitch/api/db/model/sources.py b/deployments/api/src/stitch/api/db/model/sources.py deleted file mode 100644 index 6b8e5af..0000000 --- a/deployments/api/src/stitch/api/db/model/sources.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing_extensions import Self - -from collections.abc import Mapping, MutableMapping -from typing import Final, Generic, TypeVar, TypedDict, get_args, get_origin -from pydantic import BaseModel -from sqlalchemy import CheckConstraint, inspect -from sqlalchemy.orm import Mapped, mapped_column -from .common import Base -from .types import PORTABLE_BIGINT, StitchJson -from stitch.api.entities import ( - CCReservoirsSource, - GemSource, - IdType, - RMIManualSource, - SourceKey, - WMData, - GemData, - RMIManualData, - CCReservoirsData, - WMSource, -) - - -def float_constraint( - colname: str, min_: float | None = None, max_: float | None = None -) -> CheckConstraint: - min_str = f"{colname} >= {min_}" if min_ is not None else None - max_str = f"{colname} <= {max_}" if max_ is not None else None - expr = " AND ".join(filter(None, (min_str, max_str))) - return CheckConstraint(expr) - - -def lat_constraints(colname: str): - return float_constraint(colname, -90, 90) - - -def lon_constraints(colname: str): - return float_constraint(colname, -180, 180) - - -TModelIn = TypeVar("TModelIn", bound=BaseModel) -TModelOut = TypeVar("TModelOut", bound=BaseModel) - - -class SourceBase(Base, Generic[TModelIn, TModelOut]): - __abstract__ = True - __entity_class_in__: type[TModelIn] - __entity_class_out__: type[TModelOut] - - id: Mapped[int] = mapped_column( - PORTABLE_BIGINT, primary_key=True, autoincrement=True - ) - - def __init_subclass__(cls, **kwargs) -> None: - super().__init_subclass__(**kwargs) - for base in getattr(cls, "__orig_bases__", ()): - if get_origin(base) is SourceBase: - args = get_args(base) - if len(args) >= 2: - if isinstance(args[0], type): - cls.__entity_class_in__ = args[0] - if isinstance(args[1], type): - cls.__entity_class_out__ = args[1] - break - - def as_entity(self): - return self.__entity_class_out__.model_validate(self) - - @classmethod - def from_entity(cls, entity: TModelIn) -> Self: - mapper = inspect(cls) - column_keys = {col.key for col in mapper.columns} - filtered = {k: v for k, v in entity.model_dump().items() if k in column_keys} - return cls(**filtered) - - -class GemSourceModel(SourceBase[GemData, GemSource]): - __tablename__ = "gem_sources" - - name: Mapped[str] - country: Mapped[str] - lat: Mapped[float] = mapped_column(lat_constraints("lat")) - lon: Mapped[float] = mapped_column(lon_constraints("lon")) - - -class WMSourceModel(SourceBase[WMData, WMSource]): - __tablename__ = "wm_sources" - - field_name: Mapped[str] - field_country: Mapped[str] - production: Mapped[float] - - -class RMIManualSourceModel(SourceBase[RMIManualData, RMIManualSource]): - __tablename__ = "rmi_manual_sources" - - name_override: Mapped[str | None] - gwp: Mapped[float | None] - gor: Mapped[float | None | None] = mapped_column( - float_constraint("gor", 0, 1), nullable=True - ) - country: Mapped[str | None] - latitude: Mapped[float | None] = mapped_column( - lat_constraints("latitude"), nullable=True - ) - longitude: Mapped[float | None] = mapped_column( - lon_constraints("longitude"), nullable=True - ) - - -class CCReservoirsSourceModel(SourceBase[CCReservoirsData, CCReservoirsSource]): - __tablename__ = "cc_reservoirs_sources" - - name: Mapped[str] - basin: Mapped[str] - depth: Mapped[float] - geofence: Mapped[list[tuple[float, float]]] = mapped_column(StitchJson()) - - -SourceModel = ( - GemSourceModel | WMSourceModel | RMIManualSourceModel | CCReservoirsSourceModel -) -SourceModelCls = type[SourceModel] - -SOURCE_TABLES: Final[Mapping[SourceKey, SourceModelCls]] = { - "gem": GemSourceModel, - "wm": WMSourceModel, - "rmi": RMIManualSourceModel, - "cc": CCReservoirsSourceModel, -} - - -class SourceModelData(TypedDict, total=False): - gem: MutableMapping[IdType, GemSourceModel] - wm: MutableMapping[IdType, WMSourceModel] - cc: MutableMapping[IdType, CCReservoirsSourceModel] - rmi: MutableMapping[IdType, RMIManualSourceModel] - - -def empty_source_model_data(): - return SourceModelData(gem={}, wm={}, cc={}, rmi={}) diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index c0d19bc..1ad9358 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py +++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -1,93 +1,38 @@ import asyncio -from collections import defaultdict -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from functools import partial from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from starlette.status import HTTP_404_NOT_FOUND - -from stitch.api.db.model.sources import SOURCE_TABLES, SourceModel +from stitch.api.resources.entities import CreateResource, Resource from stitch.api.auth import CurrentUser -from stitch.api.entities import ( - CreateResource, - CreateResourceSourceData, - CreateSourceData, - Resource, - SourceData, - SourceKey, -) from .model import ( - CCReservoirsSourceModel, - GemSourceModel, - MembershipModel, - RMIManualSourceModel, ResourceModel, - WMSourceModel, ) -async def get_or_create_source_models( - session: AsyncSession, - data: CreateResourceSourceData, -) -> Mapping[SourceKey, Sequence[SourceModel]]: - result: dict[SourceKey, list[SourceModel]] = defaultdict(list) - for key, model_cls in SOURCE_TABLES.items(): - for item in data.get(key): - if isinstance(item, int): - src_model = await session.get(model_cls, item) - if src_model is None: - continue - result[key].append(src_model) - else: - result[key].append(model_cls.from_entity(item)) - return result - - -def resource_model_to_empty_entity(model: ResourceModel): - return Resource( - id=model.id, - name=model.name, - country=model.country, - source_data=SourceData(), - constituents=[], - created=model.created, - updated=model.updated, - ) - - async def resource_model_to_entity( session: AsyncSession, model: ResourceModel ) -> Resource: - source_model_data = await model.get_source_data(session) - source_data = SourceData.model_validate(source_model_data) + # Domain-agnostic: constituents are just ids, and there is no source_data. constituent_models = await ResourceModel.get_constituents_by_root_id( session, model.id ) - constituents = [ - resource_model_to_empty_entity(cm) - for cm in constituent_models - if cm.id != model.id - ] + constituent_ids = [m.id for m in constituent_models if m.id != model.id] return Resource( id=model.id, name=model.name, - country=model.country, - source_data=source_data, - constituents=constituents, - created=model.created, - updated=model.updated, + repointed_to=model.repointed_id, + constituents=constituent_ids, + created=str(model.created) if getattr(model, "created", None) else None, + updated=str(model.updated) if getattr(model, "updated", None) else None, ) async def get_all(session: AsyncSession) -> Sequence[Resource]: - stmt = ( - select(ResourceModel) - .where(ResourceModel.repointed_id.is_(None)) - .options(selectinload(ResourceModel.memberships)) - ) + stmt = select(ResourceModel).where(ResourceModel.repointed_id.is_(None)) models = (await session.scalars(stmt)).all() fn = partial(resource_model_to_entity, session) return await asyncio.gather(*[fn(m) for m in models]) @@ -99,58 +44,18 @@ async def get(session: AsyncSession, id: int): raise HTTPException( status_code=HTTP_404_NOT_FOUND, detail=f"No Resource with id `{id}` found." ) - await session.refresh(model, ["memberships"]) return await resource_model_to_entity(session, model) async def create(session: AsyncSession, user: CurrentUser, resource: CreateResource): """ - Here we create a resource either from new source data or existing source data. It's also possible - to create an empty resource with no reference to source data. - - - create the resource - - create the sources - - create membership + Domain-agnostic create: + - create the resource row only """ model = ResourceModel.create( - created_by=user, name=resource.name, country=resource.country + created_by=user, + name=resource.name, ) session.add(model) - if resource.source_data: - src_model_groups = await get_or_create_source_models( - session, resource.source_data - ) - for src_key, src_models in src_model_groups.items(): - session.add_all(src_models) - await session.flush() - for src_model in src_models: - session.add( - MembershipModel.create( - created_by=user, - resource=model, - source=src_key, - source_pk=src_model.id, - ) - ) await session.flush() - await session.refresh(model, ["memberships"]) return await resource_model_to_entity(session, model) - - -async def create_source_data(session: AsyncSession, data: CreateSourceData): - """ - For bulk inserting data into source tables. - """ - gems = tuple(GemSourceModel.from_entity(gem) for gem in data.gem) - wms = tuple(WMSourceModel.from_entity(wm) for wm in data.wm) - rmis = tuple(RMIManualSourceModel.from_entity(rmi) for rmi in data.rmi) - ccs = tuple(CCReservoirsSourceModel.from_entity(cc) for cc in data.cc) - - session.add_all(gems + wms + rmis + ccs) - await session.flush() - return SourceData( - gem={g.id: g.as_entity() for g in gems}, - wm={wm.id: wm.as_entity() for wm in wms}, - rmi={rmi.id: rmi.as_entity() for rmi in rmis}, - cc={cc.id: cc.as_entity() for cc in ccs}, - ) diff --git a/deployments/api/src/stitch/api/entities.py b/deployments/api/src/stitch/api/entities.py index 337a5ee..fecbfab 100644 --- a/deployments/api/src/stitch/api/entities.py +++ b/deployments/api/src/stitch/api/entities.py @@ -1,165 +1,4 @@ -from collections.abc import Sequence -from datetime import datetime -from typing import ( - Annotated, - Generic, - Literal, - Mapping, - Protocol, - TypeVar, - runtime_checkable, -) -from uuid import UUID -from pydantic import BaseModel, ConfigDict, EmailStr, Field - -IdType = int | str | UUID - - -@runtime_checkable -class HasId(Protocol): - @property - def id(self) -> IdType: ... - - -GEM_SRC = Literal["gem"] -WM_SRC = Literal["wm"] -RMI_SRC = Literal["rmi"] -CC_SRC = Literal["cc"] - -SourceKey = GEM_SRC | WM_SRC | RMI_SRC | CC_SRC - -TSourceKey = TypeVar("TSourceKey", bound=SourceKey) - - -class Timestamped(BaseModel): - created: datetime = Field(default_factory=datetime.now) - updated: datetime = Field(default_factory=datetime.now) - - -class Identified(BaseModel): - id: IdType - - -class SourceBase(BaseModel, Generic[TSourceKey]): - source: TSourceKey - id: IdType - - -class SourceRef(BaseModel): - source: SourceKey - id: int - - -# The sources will come in and be initially stored in a raw table. -# That raw table will be an append-only table. -# We'll translate that data into one of the below structures, so each source will have a `UUID` or similar that -# references their id in the "raw" table. -# When pulling into the internal "sources" table, each will get a new unique id which is what the memberships will reference - - -class GemData(BaseModel): - name: str - lat: float = Field(ge=-90, le=90) - lon: float = Field(ge=-180, le=180) - country: str - - -class GemSource(Identified, GemData): - model_config = ConfigDict(from_attributes=True) - - -class WMData(BaseModel): - field_name: str - field_country: str - production: float - - -class WMSource(Identified, WMData): - model_config = ConfigDict(from_attributes=True) - - -class RMIManualData(BaseModel): - name_override: str - gwp: float - gor: float = Field(gt=0, lt=1) - country: str - latitude: float = Field(ge=-90, le=90) - longitude: float = Field(ge=-180, le=180) - - -class RMIManualSource(Identified, RMIManualData): - model_config = ConfigDict(from_attributes=True) - - -class CCReservoirsData(BaseModel): - name: str - basin: str - depth: float - geofence: Sequence[tuple[float, float]] - - -class CCReservoirsSource(Identified, CCReservoirsData): - model_config = ConfigDict(from_attributes=True) - - -OGSISourcePayload = Annotated[ - GemSource | WMSource | RMIManualSource | CCReservoirsSource, - Field(discriminator="source"), -] - - -class SourceData(BaseModel): - gem: Mapping[IdType, GemSource] = Field(default_factory=dict) - wm: Mapping[IdType, WMSource] = Field(default_factory=dict) - rmi: Mapping[IdType, RMIManualSource] = Field(default_factory=dict) - cc: Mapping[IdType, CCReservoirsSource] = Field(default_factory=dict) - - -class CreateSourceData(BaseModel): - gem: Sequence[GemData] = Field(default_factory=list) - wm: Sequence[WMData] = Field(default_factory=list) - rmi: Sequence[RMIManualData] = Field(default_factory=list) - cc: Sequence[CCReservoirsData] = Field(default_factory=list) - - -class CreateResourceSourceData(BaseModel): - """Allows for creating source data or referencing existing sources by ID. - - It can be used in isolation to insert source data or used with a new/existing resource to automatically add - memberships to the resource. - """ - - gem: Sequence[GemData | int] = Field(default_factory=list) - wm: Sequence[WMData | int] = Field(default_factory=list) - rmi: Sequence[RMIManualData | int] = Field(default_factory=list) - cc: Sequence[CCReservoirsData | int] = Field(default_factory=list) - - def get(self, key: SourceKey): - if key == "gem": - return self.gem - elif key == "wm": - return self.wm - elif key == "rmi": - return self.rmi - elif key == "cc": - return self.cc - raise ValueError(f"Unknown source key: {key}") - - -class ResourceBase(BaseModel): - name: str | None = Field(default=None) - country: str | None = Field(default=None) - repointed_to: "Resource | None" = Field(default=None) - - -class Resource(ResourceBase, Timestamped): - id: int - source_data: SourceData - constituents: Sequence["Resource"] - - -class CreateResource(ResourceBase): - source_data: CreateResourceSourceData | None +from pydantic import BaseModel, EmailStr, Field class User(BaseModel): @@ -168,6 +7,3 @@ class User(BaseModel): role: str | None = None email: EmailStr name: str - - -class SourceSelectionLogic(BaseModel): ... diff --git a/deployments/api/src/stitch/api/resources/entities.py b/deployments/api/src/stitch/api/resources/entities.py new file mode 100644 index 0000000..e1f1471 --- /dev/null +++ b/deployments/api/src/stitch/api/resources/entities.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + + +class CreateResource(BaseModel): + """Domain-agnostic create model.""" + + model_config = ConfigDict(extra="forbid") + name: str | None = None + + +class Resource(BaseModel): + """Domain-agnostic read model.""" + + id: int + name: str | None = None + repointed_to: int | None = None + constituents: frozenset[int] = Field(default_factory=frozenset) diff --git a/deployments/api/src/stitch/api/routers/resources.py b/deployments/api/src/stitch/api/routers/resources.py index 8d63c15..a2e7d94 100644 --- a/deployments/api/src/stitch/api/routers/resources.py +++ b/deployments/api/src/stitch/api/routers/resources.py @@ -5,7 +5,7 @@ from stitch.api.db import resource_actions from stitch.api.db.config import UnitOfWorkDep from stitch.api.auth import CurrentUser -from stitch.api.entities import CreateResource, Resource +from stitch.api.resources.entities import CreateResource, Resource router = APIRouter( diff --git a/deployments/api/tests/conftest.py b/deployments/api/tests/conftest.py index 41ee121..1782c11 100644 --- a/deployments/api/tests/conftest.py +++ b/deployments/api/tests/conftest.py @@ -9,24 +9,13 @@ from stitch.api.db.config import UnitOfWork, get_uow from stitch.api.db.model import ( - CCReservoirsSourceModel, - GemSourceModel, - RMIManualSourceModel, StitchBase, UserModel, - WMSourceModel, ) from stitch.api.auth import get_current_user from stitch.api.entities import User from stitch.api.main import app -from .utils import ( - CC_DEFAULTS, - GEM_DEFAULTS, - RMI_DEFAULTS, - WM_DEFAULTS, -) - @pytest.fixture def anyio_backend() -> str: @@ -184,117 +173,3 @@ def override_get_current_user() -> User: base_url="http://test/api/v1", ) as ac: yield ac - - -@pytest.fixture -async def existing_gem_source( - seeded_integration_session: AsyncSession, -) -> GemSourceModel: - """Pre-create a GEM source in DB, return model with ID.""" - model = GemSourceModel( - name=GEM_DEFAULTS["name"], - lat=GEM_DEFAULTS["lat"], - lon=GEM_DEFAULTS["lon"], - country=GEM_DEFAULTS["country"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_wm_source( - seeded_integration_session: AsyncSession, -) -> WMSourceModel: - """Pre-create a WM source in DB, return model with ID.""" - model = WMSourceModel( - field_name=WM_DEFAULTS["field_name"], - field_country=WM_DEFAULTS["field_country"], - production=WM_DEFAULTS["production"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_rmi_source( - seeded_integration_session: AsyncSession, -) -> RMIManualSourceModel: - """Pre-create an RMI source in DB, return model with ID.""" - model = RMIManualSourceModel( - name_override=RMI_DEFAULTS["name_override"], - gwp=RMI_DEFAULTS["gwp"], - gor=RMI_DEFAULTS["gor"], - country=RMI_DEFAULTS["country"], - latitude=RMI_DEFAULTS["latitude"], - longitude=RMI_DEFAULTS["longitude"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_cc_source( - seeded_integration_session: AsyncSession, -) -> CCReservoirsSourceModel: - """Pre-create a CC source in DB, return model with ID.""" - model = CCReservoirsSourceModel( - name=CC_DEFAULTS["name"], - basin=CC_DEFAULTS["basin"], - depth=CC_DEFAULTS["depth"], - geofence=list(CC_DEFAULTS["geofence"]), - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_sources( - seeded_integration_session: AsyncSession, -) -> dict[str, list[int]]: - """Create 2 of each source type, return dict mapping source key to list of IDs.""" - session = seeded_integration_session - - gems = [ - GemSourceModel(name=f"GEM {i}", lat=45.0 + i, lon=-120.0 + i, country="USA") - for i in range(2) - ] - wms = [ - WMSourceModel( - field_name=f"WM Field {i}", field_country="USA", production=1000.0 * (i + 1) - ) - for i in range(2) - ] - rmis = [ - RMIManualSourceModel( - name_override=f"RMI {i}", - gwp=25.0, - gor=0.5, - country="USA", - latitude=40.0 + i, - longitude=-100.0 + i, - ) - for i in range(2) - ] - ccs = [ - CCReservoirsSourceModel( - name=f"CC Reservoir {i}", - basin="Permian", - depth=3000.0, - geofence=[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)], - ) - for i in range(2) - ] - - session.add_all(gems + wms + rmis + ccs) - await session.flush() - - return { - "gem": [g.id for g in gems], - "wm": [w.id for w in wms], - "rmi": [r.id for r in rmis], - "cc": [c.id for c in ccs], - } diff --git a/deployments/api/tests/db/test_resource_actions.py b/deployments/api/tests/db/test_resource_actions.py index 39538a2..a2a83b1 100644 --- a/deployments/api/tests/db/test_resource_actions.py +++ b/deployments/api/tests/db/test_resource_actions.py @@ -1,52 +1,25 @@ -"""Database integration tests for resource_actions module.""" +"""Database integration tests for domain-agnostic resource_actions.""" import pytest from fastapi import HTTPException -from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from stitch.api.db import resource_actions -from stitch.api.db.model import ( - GemSourceModel, - MembershipModel, - ResourceModel, -) -from stitch.api.entities import CreateSourceData, GemData, User, WMData - -from tests.utils import ( - make_cc_data, - make_create_resource, - make_empty_resource, - make_gem_data, - make_resource_with_existing_ids, - make_resource_with_mixed_sources, - make_resource_with_new_sources, - make_rmi_data, - make_source_data, - make_wm_data, -) - - -class TestGetResourceActionUnit: ... - - -class TestCreateResourceActionUnit: ... - - -class TestCreateSourceDataActionUnit: ... +from stitch.api.db.model import ResourceModel +from stitch.api.entities import User +from tests.utils import make_create_resource, make_empty_resource class TestCreateResourceActionIntegration: """Integration tests for resource_actions.create() with real database.""" @pytest.mark.anyio - async def test_creates_resource_with_no_source_data( + async def test_creates_resource_with_minimal_payload( self, seeded_integration_session: AsyncSession, test_user: User, ): - """Resource with no source data persists correctly.""" - resource_in = make_empty_resource(name="Empty Resource", country="USA") + resource_in = make_empty_resource(name=None) result = await resource_actions.create( session=seeded_integration_session, @@ -55,221 +28,18 @@ async def test_creates_resource_with_no_source_data( ) assert result.id is not None - assert result.name == "Empty Resource" - assert result.country == "USA" + assert result.name is None db_resource = await seeded_integration_session.get(ResourceModel, result.id) assert db_resource is not None - assert db_resource.name == "Empty Resource" - - membership_count = ( - await seeded_integration_session.execute( - select(func.count()).select_from(MembershipModel) - ) - ).scalar() - assert membership_count == 0 - - @pytest.mark.anyio - async def test_creates_resource_with_new_gem_source( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """New GEM source creates resource, source, and membership.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="Test GEM Field", lat=40.0, lon=-100.0).model, - name="With GEM", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - assert result.id is not None - assert result.name == "With GEM" - - gem_sources = ( - ( - await seeded_integration_session.execute( - select(GemSourceModel).where( - GemSourceModel.name == "Test GEM Field" - ) - ) - ) - .scalars() - .all() - ) - assert len(gem_sources) == 1 - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 1 - assert memberships[0].source == "gem" - - @pytest.mark.anyio - async def test_creates_resource_with_new_sources_all_types( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """Resource with all four source types creates correct memberships.""" - source_data = make_source_data( - gem=[make_gem_data(name="All Types GEM").model], - wm=[make_wm_data(field_name="All Types WM").model], - rmi=[make_rmi_data(name_override="All Types RMI").model], - cc=[make_cc_data(name="All Types CC").model], - ) - resource_in = make_create_resource( - name="All Sources Resource", - source_data=source_data, - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - sources = {m.source for m in memberships} - assert sources == {"gem", "wm", "rmi", "cc"} - - @pytest.mark.anyio - async def test_creates_resource_with_existing_gem_id( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Existing source ID creates membership without new source record.""" - resource_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="With Existing GEM", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - assert result.id is not None - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 1 - assert memberships[0].source_pk == existing_gem_source.id - - @pytest.mark.anyio - async def test_creates_resource_with_mixed_new_and_existing( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Mix of new sources and existing IDs creates correct memberships.""" - new_gem = make_gem_data(name="Brand New GEM").model - - resource_in = make_resource_with_mixed_sources( - new_gem=new_gem, - existing_gem_ids=[existing_gem_source.id], - name="Mixed Sources", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - gem_memberships = [m for m in memberships if m.source == "gem"] - assert len(gem_memberships) == 2 - - @pytest.mark.anyio - async def test_creates_resource_with_multiple_sources_same_type( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """Multiple sources of same type creates multiple memberships.""" - gems = [make_gem_data(name=f"Multi GEM {i}").model for i in range(3)] - resource_in = make_resource_with_new_sources(gem=gems, name="Multiple GEMs") - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - assert len(memberships) == 3 - assert all(m.source == "gem" for m in memberships) @pytest.mark.anyio - async def test_nonexistent_source_id_creates_no_membership( + async def test_creates_resource_with_label( self, seeded_integration_session: AsyncSession, test_user: User, ): - """Nonexistent source ID is skipped, no membership created.""" - resource_in = make_resource_with_existing_ids( - gem_ids=[99999], - name="With Bad ID", - ) + resource_in = make_create_resource(name="Test Label") result = await resource_actions.create( session=seeded_integration_session, @@ -278,85 +48,27 @@ async def test_nonexistent_source_id_creates_no_membership( ) assert result.id is not None + assert result.name == "Test Label" - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 0 - - @pytest.mark.anyio - async def test_source_can_be_linked_to_multiple_resources( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Verify many-to-many: same source record can belong to multiple resources.""" - resource1_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="First Resource", - ) - result1 = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource1_in.model, - ) - - resource2_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="Second Resource", - ) - result2 = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource2_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.source == "gem", - MembershipModel.source_pk == existing_gem_source.id, - ) - ) - ) - .scalars() - .all() - ) - - assert len(memberships) == 2 - resource_ids = {m.resource_id for m in memberships} - assert resource_ids == {result1.id, result2.id} + db_resource = await seeded_integration_session.get(ResourceModel, result.id) + assert db_resource is not None + # DB ResourceModel may store `.name`; tolerate either while refactor settles. + assert getattr(db_resource, "name", None) in (None, "Test Label") class TestGetResourceActionIntegration: """Integration tests for resource_actions.get() with real database.""" @pytest.mark.anyio - async def test_get_returns_resource_with_populated_source_data( + async def test_get_returns_resource( self, seeded_integration_session: AsyncSession, test_user: User, ): - """GET returns resource with source_data populated.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="Get Test GEM").model, - name="Get Test", - ) - created = await resource_actions.create( session=seeded_integration_session, user=test_user, - resource=resource_in.model, + resource=make_create_resource(name="Get Test").model, ) result = await resource_actions.get( @@ -364,61 +76,46 @@ async def test_get_returns_resource_with_populated_source_data( id=created.id, ) + assert result.id == created.id assert result.name == "Get Test" - assert len(result.source_data.gem) == 1 @pytest.mark.anyio async def test_get_nonexistent_raises_404( self, seeded_integration_session: AsyncSession, ): - """GET nonexistent resource raises HTTPException with 404.""" with pytest.raises(HTTPException) as exc_info: await resource_actions.get( session=seeded_integration_session, id=99999, ) - assert exc_info.value.status_code == 404 -class TestCreateSourceDataActionIntegration: - """Integration tests for resource_actions.create_source_data().""" +class TestListResourcesActionIntegration: + """Integration tests for resource_actions.get_all() with real database.""" @pytest.mark.anyio - async def test_bulk_creates_sources_returns_source_data_with_ids( + async def test_get_all_returns_sequence( self, seeded_integration_session: AsyncSession, + test_user: User, ): - """Bulk create sources returns SourceData with assigned IDs.""" - source_data = CreateSourceData( - gem=[ - GemData(name="Bulk GEM 1", lat=40.0, lon=-100.0, country="USA"), - GemData(name="Bulk GEM 2", lat=41.0, lon=-101.0, country="CAN"), - ], - wm=[ - WMData(field_name="Bulk WM", field_country="USA", production=5000.0), - ], + # create a couple resources + await resource_actions.create( + session=seeded_integration_session, + user=test_user, + resource=make_create_resource(name="A").model, ) - - result = await resource_actions.create_source_data( + await resource_actions.create( session=seeded_integration_session, - data=source_data, + user=test_user, + resource=make_create_resource(name="B").model, ) - assert len(result.gem) == 2 - assert len(result.wm) == 1 - - for gem in result.gem.values(): - assert gem.id is not None + results = await resource_actions.get_all(session=seeded_integration_session) + assert isinstance(results, (list, tuple)) + assert len(results) >= 2 - db_gems = ( - ( - await seeded_integration_session.execute( - select(GemSourceModel).where(GemSourceModel.name.like("Bulk GEM%")) - ) - ) - .scalars() - .all() - ) - assert len(db_gems) == 2 + labels = {r.name for r in results} + assert {"A", "B"} <= labels diff --git a/deployments/api/tests/routers/test_resources_integration.py b/deployments/api/tests/routers/test_resources_integration.py index 58f2d13..cbcd64f 100644 --- a/deployments/api/tests/routers/test_resources_integration.py +++ b/deployments/api/tests/routers/test_resources_integration.py @@ -5,12 +5,7 @@ from stitch.api.db.model import ResourceModel -from tests.utils import ( - make_empty_resource, - make_gem_data, - make_resource_with_new_sources, - make_wm_data, -) +from tests.utils import make_empty_resource, make_create_resource class TestResourcesIntegration: @@ -27,31 +22,20 @@ async def test_get_nonexistent_returns_404(self, integration_client): @pytest.mark.anyio async def test_create_resource_returns_resource(self, integration_client): """POST /resources/ returns the created resource.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="GEM Integration Field", lat=40.0, lon=-100.0).model, - name="Integration Test Resource", - country="USA", - ) + resource_in = make_create_resource(name="Integration Test Resource") response = await integration_client.post("/resources/", json=resource_in.data) assert response.status_code == 200 data = response.json() assert data["name"] == "Integration Test Resource" - assert data["country"] == "USA" assert "id" in data assert data["id"] > 0 @pytest.mark.anyio async def test_create_and_get_resource(self, integration_client): """POST creates resource, GET retrieves it.""" - resource_in = make_resource_with_new_sources( - wm=make_wm_data( - field_name="WM Roundtrip Field", field_country="CAN", production=5000.0 - ).model, - name="Roundtrip Resource", - country="CAN", - ) + resource_in = make_create_resource(name="Roundtrip Resource") create_response = await integration_client.post( "/resources/", json=resource_in.data @@ -66,20 +50,13 @@ async def test_create_and_get_resource(self, integration_client): data = get_response.json() assert data["id"] == created_id assert data["name"] == "Roundtrip Resource" - assert data["country"] == "CAN" @pytest.mark.anyio async def test_create_persists_to_database( self, integration_client, integration_session_factory ): """POST resource is persisted and queryable directly.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data( - name="GEM Persist Field", lat=25.0, lon=-105.0, country="MEX" - ).model, - name="Persisted Resource", - country="MEX", - ) + resource_in = make_create_resource(name="Persisted Resource") response = await integration_client.post("/resources/", json=resource_in.data) @@ -94,12 +71,11 @@ async def test_create_persists_to_database( assert resource is not None assert resource.name == "Persisted Resource" - assert resource.country == "MEX" @pytest.mark.anyio async def test_create_with_minimal_data(self, integration_client): """POST /resources/ works with only required fields (no source data).""" - resource_in = make_empty_resource(name=None, country=None) + resource_in = make_empty_resource(name=None) response = await integration_client.post("/resources/", json=resource_in.data) @@ -107,4 +83,3 @@ async def test_create_with_minimal_data(self, integration_client): data = response.json() assert data["id"] > 0 assert data["name"] is None - assert data["country"] is None diff --git a/deployments/api/tests/routers/test_resources_unit.py b/deployments/api/tests/routers/test_resources_unit.py index b8841ce..87dd4dc 100644 --- a/deployments/api/tests/routers/test_resources_unit.py +++ b/deployments/api/tests/routers/test_resources_unit.py @@ -1,6 +1,5 @@ """Unit tests for resources router with mocked dependencies.""" -from datetime import datetime, timezone from unittest.mock import AsyncMock, patch import pytest @@ -8,31 +7,22 @@ from starlette.status import HTTP_404_NOT_FOUND from stitch.api.db.config import get_uow -from stitch.api.entities import Resource, SourceData from stitch.api.main import app -from tests.utils import ( - make_gem_data, - make_resource_with_new_sources, - make_wm_data, -) +from tests.utils import make_create_resource + +# Import the response model used by the router (domain-agnostic). +from stitch.api.resources.entities import Resource def make_resource( id: int = 1, - name: str = "Test Resource", - country: str = "USA", + name: str | None = "Test Resource", ) -> Resource: """Factory for creating Resource entities for tests.""" - now = datetime.now(timezone.utc) return Resource( id=id, name=name, - country=country, - source_data=SourceData(), - constituents=[], - created=now, - updated=now, ) @@ -88,14 +78,8 @@ class TestCreateResourceUnit: @pytest.mark.anyio async def test_creates_resource_with_user(self, async_client, mock_uow, test_user): """POST /resources/ calls repo.create with user and data.""" - expected = make_resource(id=123, name="New Resource", country="CAN") - resource_in = make_resource_with_new_sources( - gem=make_gem_data( - name="GEM Field", lat=45.0, lon=-120.0, country="CAN" - ).model, - name="New Resource", - country="CAN", - ) + expected = make_resource(id=123, name="New Resource") + resource_in = make_create_resource(name="New Resource") async def override_get_uow(): yield mock_uow @@ -116,12 +100,7 @@ async def override_get_uow(): async def test_returns_created_resource(self, async_client, mock_uow): """POST /resources/ returns the created resource entity.""" expected = make_resource(id=456, name="Created Resource") - resource_in = make_resource_with_new_sources( - wm=make_wm_data( - field_name="WM Field", field_country="USA", production=1000.0 - ).model, - name="Created Resource", - ) + resource_in = make_create_resource(name="Created Resource") async def override_get_uow(): yield mock_uow @@ -147,14 +126,6 @@ async def override_get_uow(): app.dependency_overrides[get_uow] = override_get_uow - response = await async_client.post( - "/resources/", - json={ - "name": "Test Resource", - "source_data": { - "gem": [{"invalid_field": "bad"}], - }, - }, - ) + response = await async_client.post("/resources/", json={"label": 123}) assert response.status_code == 422 diff --git a/deployments/api/tests/utils.py b/deployments/api/tests/utils.py index 490383c..90a52d2 100644 --- a/deployments/api/tests/utils.py +++ b/deployments/api/tests/utils.py @@ -6,20 +6,12 @@ from __future__ import annotations -from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Generic, TypeVar from pydantic import BaseModel -from stitch.api.entities import ( - CCReservoirsData, - CreateResource, - CreateResourceSourceData, - GemData, - RMIManualData, - WMData, -) +from stitch.api.resources.entities import CreateResource T = TypeVar("T", bound=BaseModel) @@ -33,241 +25,23 @@ class FactoryResult(Generic[T]): @property def data(self) -> dict[str, Any]: """Return dict representation via model_dump().""" - return self.model.model_dump() - - -# Static defaults for each source type (no id - these are for creation) -GEM_DEFAULTS: dict[str, Any] = { - "name": "Default GEM Field", - "lat": 45.0, - "lon": -120.0, - "country": "USA", -} - -WM_DEFAULTS: dict[str, Any] = { - "field_name": "Default WM Field", - "field_country": "USA", - "production": 1000.0, -} - -RMI_DEFAULTS: dict[str, Any] = { - "name_override": "Default RMI", - "gwp": 25.0, - "gor": 0.5, - "country": "USA", - "latitude": 40.0, - "longitude": -100.0, -} - -CC_DEFAULTS: dict[str, Any] = { - "name": "Default CC Reservoir", - "basin": "Permian", - "depth": 3000.0, - "geofence": [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)], -} - - -def make_gem_data( - name: str = GEM_DEFAULTS["name"], - lat: float = GEM_DEFAULTS["lat"], - lon: float = GEM_DEFAULTS["lon"], - country: str = GEM_DEFAULTS["country"], -) -> FactoryResult[GemData]: - """Create a GemData with both model and dict representations.""" - return FactoryResult(model=GemData(name=name, lat=lat, lon=lon, country=country)) - - -def make_wm_data( - field_name: str = WM_DEFAULTS["field_name"], - field_country: str = WM_DEFAULTS["field_country"], - production: float = WM_DEFAULTS["production"], -) -> FactoryResult[WMData]: - """Create a WMData with both model and dict representations.""" - return FactoryResult( - model=WMData( - field_name=field_name, field_country=field_country, production=production - ) - ) - - -def make_rmi_data( - name_override: str = RMI_DEFAULTS["name_override"], - gwp: float = RMI_DEFAULTS["gwp"], - gor: float = RMI_DEFAULTS["gor"], - country: str = RMI_DEFAULTS["country"], - latitude: float = RMI_DEFAULTS["latitude"], - longitude: float = RMI_DEFAULTS["longitude"], -) -> FactoryResult[RMIManualData]: - """Create an RMIManualData with both model and dict representations.""" - return FactoryResult( - model=RMIManualData( - name_override=name_override, - gwp=gwp, - gor=gor, - country=country, - latitude=latitude, - longitude=longitude, - ) - ) - - -def make_cc_data( - name: str = CC_DEFAULTS["name"], - basin: str = CC_DEFAULTS["basin"], - depth: float = CC_DEFAULTS["depth"], - geofence: Sequence[tuple[float, float]] = CC_DEFAULTS["geofence"], -) -> FactoryResult[CCReservoirsData]: - """Create a CCReservoirsData with both model and dict representations.""" - return FactoryResult( - model=CCReservoirsData( - name=name, basin=basin, depth=depth, geofence=list(geofence) - ) - ) - - -def make_source_data( - gem: Sequence[GemData | int] | None = None, - wm: Sequence[WMData | int] | None = None, - rmi: Sequence[RMIManualData | int] | None = None, - cc: Sequence[CCReservoirsData | int] | None = None, -) -> FactoryResult[CreateResourceSourceData]: - """Create CreateResourceSourceData with both model and dict representations. - - Args: - gem: List of GemData models or existing source IDs - wm: List of WMData models or existing source IDs - rmi: List of RMIManualData models or existing source IDs - cc: List of CCReservoirsData models or existing source IDs - - Returns: - FactoryResult with model and data (dict) attributes - """ - return FactoryResult( - model=CreateResourceSourceData( - gem=list(gem or []), - wm=list(wm or []), - rmi=list(rmi or []), - cc=list(cc or []), - ) - ) + return self.model.model_dump(mode="json") def make_create_resource( - name: str | None = "Test Resource", - country: str | None = "USA", - source_data: CreateResourceSourceData - | FactoryResult[CreateResourceSourceData] - | None = None, + *, + name: str | None = None, ) -> FactoryResult[CreateResource]: - """Create a CreateResource with both model and dict representations. - - Args: - name: Resource name (optional) - country: Country code (optional) - source_data: Either a CreateResourceSourceData model, a FactoryResult, - or None for empty source data - - Returns: - FactoryResult with model and data (dict) attributes - """ - if source_data is None: - sd_model = None - elif isinstance(source_data, FactoryResult): - sd_model = source_data.model - else: - sd_model = source_data - - return FactoryResult( - model=CreateResource(name=name, country=country, source_data=sd_model) - ) + """Create a minimal, domain-agnostic CreateResource payload.""" + return FactoryResult(model=CreateResource(name=name)) # Convenience factory functions for common test scenarios def make_empty_resource( - name: str | None = "Empty Resource", - country: str | None = "USA", + *, + name: str | None = None, ) -> FactoryResult[CreateResource]: - """Create a resource with no source data.""" - return make_create_resource(name=name, country=country, source_data=None) - - -def make_resource_with_new_sources( - gem: GemData | Sequence[GemData] | None = None, - wm: WMData | Sequence[WMData] | None = None, - rmi: RMIManualData | Sequence[RMIManualData] | None = None, - cc: CCReservoirsData | Sequence[CCReservoirsData] | None = None, - name: str | None = "Resource with Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource with new source data only (no existing IDs).""" - - def to_list(item: Any | Sequence[Any] | None) -> list[Any]: - if item is None: - return [] - if isinstance(item, (list, tuple)): - return list(item) - return [item] - - source_data = make_source_data( - gem=to_list(gem), - wm=to_list(wm), - rmi=to_list(rmi), - cc=to_list(cc), - ) - return make_create_resource(name=name, country=country, source_data=source_data) - - -def make_resource_with_existing_ids( - gem_ids: Sequence[int] | None = None, - wm_ids: Sequence[int] | None = None, - rmi_ids: Sequence[int] | None = None, - cc_ids: Sequence[int] | None = None, - name: str | None = "Resource with Existing Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource referencing existing source IDs only.""" - source_data = make_source_data( - gem=list(gem_ids or []), - wm=list(wm_ids or []), - rmi=list(rmi_ids or []), - cc=list(cc_ids or []), - ) - return make_create_resource(name=name, country=country, source_data=source_data) - - -def make_resource_with_mixed_sources( - new_gem: GemData | Sequence[GemData] | None = None, - existing_gem_ids: Sequence[int] | None = None, - new_wm: WMData | Sequence[WMData] | None = None, - existing_wm_ids: Sequence[int] | None = None, - new_rmi: RMIManualData | Sequence[RMIManualData] | None = None, - existing_rmi_ids: Sequence[int] | None = None, - new_cc: CCReservoirsData | Sequence[CCReservoirsData] | None = None, - existing_cc_ids: Sequence[int] | None = None, - name: str | None = "Resource with Mixed Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource with a mix of new source data and existing source IDs.""" - - def to_list(item: Any | Sequence[Any] | None) -> list[Any]: - if item is None: - return [] - if isinstance(item, (list, tuple)): - return list(item) - return [item] - - gem_items: list[GemData | int] = to_list(new_gem) + list(existing_gem_ids or []) - wm_items: list[WMData | int] = to_list(new_wm) + list(existing_wm_ids or []) - rmi_items: list[RMIManualData | int] = to_list(new_rmi) + list( - existing_rmi_ids or [] - ) - cc_items: list[CCReservoirsData | int] = to_list(new_cc) + list( - existing_cc_ids or [] - ) - - source_data = make_source_data( - gem=gem_items, wm=wm_items, rmi=rmi_items, cc=cc_items - ) - return make_create_resource(name=name, country=country, source_data=source_data) + """Alias for make_create_resource() kept for readability.""" + return make_create_resource(name=name) diff --git a/packages/stitch-models/src/stitch/models/__init__.py b/packages/stitch-models/src/stitch/models/__init__.py index 5c1ff3e..6b07ead 100644 --- a/packages/stitch-models/src/stitch/models/__init__.py +++ b/packages/stitch-models/src/stitch/models/__init__.py @@ -13,6 +13,8 @@ "Source", "SourcePayload", "SourceRefTuple", + "EmptySourcePayload", + "BaseResource", ] @@ -64,3 +66,12 @@ def _no_self_reference(self) -> Self: if self.repointed_to == self.id: raise ValueError("A resource cannot be repointed to itself") return self + + +class EmptySourcePayload(SourcePayload): + """Domain-agnostic source payload container (no sources).""" + + +# A concrete, domain-agnostic Resource you can use everywhere. +# Domains can replace `EmptySourcePayload` with their own payload type. +type BaseResource[TResId: IdType] = Resource[TResId, EmptySourcePayload] diff --git a/uv.lock b/uv.lock index 05458cc..83a5486 100644 --- a/uv.lock +++ b/uv.lock @@ -1125,6 +1125,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "sqlalchemy" }, { name = "stitch-auth" }, + { name = "stitch-models" }, ] [package.dev-dependencies] @@ -1143,6 +1144,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, + { name = "stitch-models", editable = "packages/stitch-models" }, ] [package.metadata.requires-dev]