diff --git a/deployments/api/pyproject.toml b/deployments/api/pyproject.toml index 3aed224..58058c3 100644 --- a/deployments/api/pyproject.toml +++ b/deployments/api/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ "pydantic-settings>=2.12.0", "sqlalchemy>=2.0.44", "stitch-auth", + "stitch-models", + "stitch-ogsi", ] [project.scripts] @@ -41,3 +43,5 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } +stitch-models = { workspace = true } +stitch-ogsi = { workspace = true } diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 5a256c2..45ffdf3 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -5,29 +5,26 @@ import time from enum import Enum from dataclasses import dataclass -from typing import Iterable +from typing import Any from sqlalchemy import create_engine, inspect, text from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session from stitch.api.db.model import ( - CCReservoirsSourceModel, - GemSourceModel, - MembershipModel, - RMIManualSourceModel, ResourceModel, + MembershipModel, StitchBase, UserModel, - WMSourceModel, + OilGasFieldSourceModel, ) from stitch.api.entities import ( - GemData, - RMIManualData, User as UserEntity, - WMData, ) +# Domain model from stitch-ogsi package +from stitch.ogsi.model.og_field import OilGasFieldBase + """ DB init/seed job. @@ -257,7 +254,6 @@ def fail_partial(existing_tables: set[str], expected: set[str]) -> None: def create_seed_user() -> UserModel: return UserModel( - id=1, sub="seed|system", name="Seed User", email="seed@example.com", @@ -266,97 +262,79 @@ def create_seed_user() -> UserModel: def create_dev_user() -> UserModel: return UserModel( - id=2, sub="dev|local-placeholder", name="Dev Deverson", email="dev@example.com", ) -def create_seed_sources(): - gem_sources = [ - GemSourceModel.from_entity( - GemData(name="Permian Basin Field", country="USA", lat=31.8, lon=-102.3) - ), - GemSourceModel.from_entity( - GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5) - ), - ] - for i, src in enumerate(gem_sources, start=1): - src.id = i - - wm_sources = [ - WMSourceModel.from_entity( - WMData( - field_name="Eagle Ford Shale", field_country="USA", production=125000.5 - ) - ), - WMSourceModel.from_entity( - WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0) - ), - ] - for i, src in enumerate(wm_sources, start=1): - src.id = i - - rmi_sources = [ - RMIManualSourceModel.from_entity( - RMIManualData( - name_override="Custom Override Name", - gwp=25.5, - gor=0.45, - country="CAN", - latitude=56.7, - longitude=-111.4, - ) - ), - ] - for i, src in enumerate(rmi_sources, start=1): - src.id = i - - # CC Reservoir sources are intentionally omitted from the dev seed profile; - # the CCReservoirsSourceModel table is still created from SQLAlchemy metadata. - cc_sources: list[CCReservoirsSourceModel] = [] - - return gem_sources, wm_sources, rmi_sources, cc_sources - - def create_seed_resources(user: UserEntity) -> list[ResourceModel]: resources = [ - ResourceModel.create(user, name="Multi-Source Asset", country="USA"), - ResourceModel.create(user, name="Single Source Asset", country="GBR"), + ResourceModel.create(user, name="Resource Foo01"), + ResourceModel.create(user, name="Resource Bar01"), ] - for i, res in enumerate(resources, start=1): - res.id = i return resources def create_seed_memberships( user: UserEntity, resources: list[ResourceModel], - gem_sources: list[GemSourceModel], - wm_sources: list[WMSourceModel], - rmi_sources: list[RMIManualSourceModel], + sources: list[OilGasFieldSourceModel], ) -> list[MembershipModel]: memberships = [ - MembershipModel.create(user, resources[0], "gem", gem_sources[0].id), - MembershipModel.create(user, resources[0], "wm", wm_sources[0].id), - MembershipModel.create(user, resources[0], "rmi", rmi_sources[0].id), - MembershipModel.create(user, resources[1], "gem", gem_sources[1].id), + MembershipModel.create(user, resources[0], "gem", 1), + MembershipModel.create(user, resources[1], "wm", 2), ] for i, mem in enumerate(memberships, start=1): mem.id = i return memberships -def reset_sequences(engine, tables: Iterable[str]) -> None: - with engine.begin() as conn: - for t in tables: - conn.execute( - text( - f"SELECT setval('{t}_id_seq', " - f"(SELECT COALESCE(MAX(id), 0) + 1 FROM {t}), false);" - ) - ) +def create_seed_oil_gas_source_fields( + user: UserEntity, + resources: list[ResourceModel], +) -> list[OilGasFieldSourceModel]: + """Create example OilGasField rows linked 1:1 with seeded resources.""" + + raw_payloads: list[dict[str, Any]] = [ + # pretend this came from some upstream system (GEM/WM/etc) + { + "name": "Permian Alpha", + "country": "USA", + "basin": "Permian", + # extra keys demonstrate why we keep original_payload + "upstream_id": "seed-gem-0001", + "notes": "seed example", + }, + { + "name": "North Sea Bravo", + "country": "GBR", + "basin": "North Sea", + "upstream_id": "seed-wm-0002", + "notes": "seed example", + }, + ] + + og_models: list[OilGasFieldSourceModel] = [] + + for resource, raw in zip(resources, raw_payloads): + domain = OilGasFieldBase.model_validate(raw) + model = OilGasFieldSourceModel( + created_by_id=user.id, + last_updated_by_id=user.id, + ) + # Raw input (includes extra fields not in OilGasFieldBase) + model.original_payload = raw + # Canonical validated payload + model.payload = raw + model.name = domain.name + model.country = domain.country + model.basin = domain.basin + model.source = "dev-seed" + # Populate domain columns for queryability + og_models.append(model) + + return og_models def seed_dev(engine) -> None: @@ -376,39 +354,19 @@ def seed_dev(engine) -> None: name=user_model.name, ) - dev_entity = UserEntity( - id=dev_model.id, - sub=dev_model.sub, - email=dev_model.email, - name=dev_model.name, - ) - - gem_sources, wm_sources, rmi_sources, cc_sources = create_seed_sources() - session.add_all(gem_sources + wm_sources + rmi_sources + cc_sources) - resources = create_seed_resources(user_entity) - resources = create_seed_resources(dev_entity) session.add_all(resources) + session.flush() + # + # Add sample OilGasField rows for the first two resources only + og_fields = create_seed_oil_gas_source_fields(user_entity, resources) + session.add_all(og_fields) - memberships = create_seed_memberships( - user_entity, resources, gem_sources, wm_sources, rmi_sources - ) + memberships = create_seed_memberships(user_entity, resources, og_fields) session.add_all(memberships) session.commit() - reset_sequences( - engine, - tables=[ - "users", - "gem_sources", - "wm_sources", - "rmi_manual_sources", - "resources", - "memberships", - ], - ) - def seed(engine, profile: SeedProfile | str) -> None: if profile == "dev": diff --git a/deployments/api/src/stitch/api/db/model/__init__.py b/deployments/api/src/stitch/api/db/model/__init__.py index d3f74ab..d81c8af 100644 --- a/deployments/api/src/stitch/api/db/model/__init__.py +++ b/deployments/api/src/stitch/api/db/model/__init__.py @@ -1,21 +1,13 @@ from .common import Base as StitchBase -from .sources import ( - GemSourceModel, - RMIManualSourceModel, - CCReservoirsSourceModel, - WMSourceModel, -) +from .oil_gas_field_source import OilGasFieldSourceModel from .resource import MembershipStatus, MembershipModel, ResourceModel from .user import User as UserModel __all__ = [ - "CCReservoirsSourceModel", - "GemSourceModel", "MembershipModel", "MembershipStatus", - "RMIManualSourceModel", "ResourceModel", "StitchBase", "UserModel", - "WMSourceModel", + "OilGasFieldSourceModel", ] diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py new file mode 100644 index 0000000..b856e7c --- /dev/null +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import ( + Integer, + String, + JSON, +) +from sqlalchemy.orm import Mapped, mapped_column + +from .common import Base +from .mixins import TimestampMixin, UserAuditMixin + + +class OilGasFieldSourceModel(TimestampMixin, UserAuditMixin, Base): + """A single OG field source record (canonicalized), feedable into a Resource.""" + + __tablename__ = "oil_gas_field_source" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + source: Mapped[str | None] = mapped_column(String, nullable=True) + + # Flat domain columns for filtering, indexing, query, etc. + name: Mapped[str | None] = mapped_column(String, nullable=True) + country: Mapped[str | None] = mapped_column(String, nullable=True) + basin: Mapped[str | None] = mapped_column(String, nullable=True) + + # full normalized domain payload + payload: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) + + # original raw payload as given by client + original_payload: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) + + # optionally track an external ref if useful + source_ref: Mapped[str | None] = mapped_column(String, nullable=True) diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index 5493faa..a366fdd 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -18,7 +18,9 @@ SourceModel, SourceModelData, ) -from stitch.api.entities import IdType, User as UserEntity + +from stitch.models.types import IdType +from stitch.api.entities import User as UserEntity from .common import Base from .mixins import TimestampMixin, UserAuditMixin from .types import PORTABLE_BIGINT @@ -92,7 +94,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base): PORTABLE_BIGINT, ForeignKey("resources.id"), nullable=True ) name: Mapped[str | None] = mapped_column(String, nullable=True) - country: Mapped[str | None] = mapped_column(String(3), nullable=True) # SQLAlchemy will automatically see the foreign key `memberships.resource_id` # and configure the appropriate SQL statement to load the membership objects @@ -130,12 +131,10 @@ def create( cls, created_by: UserEntity, name: str | None = None, - country: str | None = None, repointed_to: int | None = None, ): return cls( name=name, - country=country, repointed_id=repointed_to, created_by_id=created_by.id, last_updated_by_id=created_by.id, diff --git a/deployments/api/src/stitch/api/db/model/sources.py b/deployments/api/src/stitch/api/db/model/sources.py index 6b8e5af..da1dede 100644 --- a/deployments/api/src/stitch/api/db/model/sources.py +++ b/deployments/api/src/stitch/api/db/model/sources.py @@ -6,20 +6,13 @@ from sqlalchemy import CheckConstraint, inspect from sqlalchemy.orm import Mapped, mapped_column from .common import Base -from .types import PORTABLE_BIGINT, StitchJson +from .types import PORTABLE_BIGINT from stitch.api.entities import ( - CCReservoirsSource, - GemSource, IdType, - RMIManualSource, SourceKey, - WMData, - GemData, - RMIManualData, - CCReservoirsData, - WMSource, ) +from .oil_gas_field_source import OilGasFieldSourceModel def float_constraint( colname: str, min_: float | None = None, max_: float | None = None @@ -74,68 +67,22 @@ def from_entity(cls, entity: TModelIn) -> Self: return cls(**filtered) -class GemSourceModel(SourceBase[GemData, GemSource]): - __tablename__ = "gem_sources" - - name: Mapped[str] - country: Mapped[str] - lat: Mapped[float] = mapped_column(lat_constraints("lat")) - lon: Mapped[float] = mapped_column(lon_constraints("lon")) - - -class WMSourceModel(SourceBase[WMData, WMSource]): - __tablename__ = "wm_sources" - - field_name: Mapped[str] - field_country: Mapped[str] - production: Mapped[float] - - -class RMIManualSourceModel(SourceBase[RMIManualData, RMIManualSource]): - __tablename__ = "rmi_manual_sources" - - name_override: Mapped[str | None] - gwp: Mapped[float | None] - gor: Mapped[float | None | None] = mapped_column( - float_constraint("gor", 0, 1), nullable=True - ) - country: Mapped[str | None] - latitude: Mapped[float | None] = mapped_column( - lat_constraints("latitude"), nullable=True - ) - longitude: Mapped[float | None] = mapped_column( - lon_constraints("longitude"), nullable=True - ) - - -class CCReservoirsSourceModel(SourceBase[CCReservoirsData, CCReservoirsSource]): - __tablename__ = "cc_reservoirs_sources" - - name: Mapped[str] - basin: Mapped[str] - depth: Mapped[float] - geofence: Mapped[list[tuple[float, float]]] = mapped_column(StitchJson()) - - SourceModel = ( - GemSourceModel | WMSourceModel | RMIManualSourceModel | CCReservoirsSourceModel + OilGasFieldSourceModel ) SourceModelCls = type[SourceModel] SOURCE_TABLES: Final[Mapping[SourceKey, SourceModelCls]] = { - "gem": GemSourceModel, - "wm": WMSourceModel, - "rmi": RMIManualSourceModel, - "cc": CCReservoirsSourceModel, + "gem": OilGasFieldSourceModel, + "wm": OilGasFieldSourceModel, + "rmi": OilGasFieldSourceModel, + "llm": OilGasFieldSourceModel, } class SourceModelData(TypedDict, total=False): - gem: MutableMapping[IdType, GemSourceModel] - wm: MutableMapping[IdType, WMSourceModel] - cc: MutableMapping[IdType, CCReservoirsSourceModel] - rmi: MutableMapping[IdType, RMIManualSourceModel] + og_field: MutableMapping[IdType, OilGasFieldSourceModel] def empty_source_model_data(): - return SourceModelData(gem={}, wm={}, cc={}, rmi={}) + return SourceModelData(og_field={}) diff --git a/deployments/api/src/stitch/api/db/og_field_source_actions.py b/deployments/api/src/stitch/api/db/og_field_source_actions.py new file mode 100644 index 0000000..acd19f5 --- /dev/null +++ b/deployments/api/src/stitch/api/db/og_field_source_actions.py @@ -0,0 +1,78 @@ +import asyncio + +from functools import partial + +from fastapi import HTTPException +from starlette.status import HTTP_404_NOT_FOUND +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from stitch.ogsi.model.og_field import OilGasFieldBase + +from .model import OilGasFieldSourceModel, ResourceModel, MembershipModel +from .resource_actions import resource_model_to_entity + + +async def create_source( + session, + raw_payload: dict[str, object], + *, + source_system: str | None = None, + source_ref: str | None = None, +) -> OilGasFieldSourceModel: + """Validate raw JSON into domain model, persist canonical + original.""" + + # domain validation (pydantic) + domain: OilGasFieldBase = OilGasFieldBase.model_validate(raw_payload) + + model = OilGasFieldSourceModel( + source_system=source_system, + source_ref=source_ref, + name=domain.name, + country=domain.country, + basin=domain.basin, + payload=domain.model_dump(), + original_payload=raw_payload, + ) + session.add(model) + await session.flush() + return model + + +async def attach_to_resource( + session, + resource_id: int, + source_row: OilGasFieldSourceModel, + created_by, +): + """Link an OG field source to a resource via membership.""" + session.add( + MembershipModel.create( + created_by=created_by, + resource=session.get(ResourceModel, resource_id), + source="rmi", + source_pk=source_row.id, + ) + ) + await session.flush() + + +async def get_source(session, id: int) -> OilGasFieldSourceModel: + model = await session.get(OilGasFieldSourceModel, id) + if model is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, detail=f"No OG field source with id `{id}`" + ) + return model + +async def list_og_resources(session): + stmt = ( + select(ResourceModel) + .where(ResourceModel.repointed_id.is_(None)) + .join(MembershipModel, MembershipModel.resource_id == ResourceModel.id) + .options(selectinload(ResourceModel.memberships)) + .distinct() + ) + models = (await session.scalars(stmt)).all() + fn = partial(resource_model_to_entity, session) + return await asyncio.gather(*[fn(m) for m in models]) diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index c0d19bc..1fdabf0 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py +++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -20,13 +20,10 @@ ) from .model import ( - CCReservoirsSourceModel, - GemSourceModel, MembershipModel, - RMIManualSourceModel, ResourceModel, - WMSourceModel, ) +from stitch.ogsi.model.og_field import OilGasFieldBase async def get_or_create_source_models( @@ -50,7 +47,6 @@ def resource_model_to_empty_entity(model: ResourceModel): return Resource( id=model.id, name=model.name, - country=model.country, source_data=SourceData(), constituents=[], created=model.created, @@ -74,7 +70,6 @@ async def resource_model_to_entity( return Resource( id=model.id, name=model.name, - country=model.country, source_data=source_data, constituents=constituents, created=model.created, @@ -113,7 +108,7 @@ async def create(session: AsyncSession, user: CurrentUser, resource: CreateResou - create membership """ model = ResourceModel.create( - created_by=user, name=resource.name, country=resource.country + created_by=user, name=resource.name ) session.add(model) if resource.source_data: @@ -141,16 +136,10 @@ async def create_source_data(session: AsyncSession, data: CreateSourceData): """ For bulk inserting data into source tables. """ - gems = tuple(GemSourceModel.from_entity(gem) for gem in data.gem) - wms = tuple(WMSourceModel.from_entity(wm) for wm in data.wm) - rmis = tuple(RMIManualSourceModel.from_entity(rmi) for rmi in data.rmi) - ccs = tuple(CCReservoirsSourceModel.from_entity(cc) for cc in data.cc) + og_fields = tuple(OilGasFieldBase.from_entity(og_field) for og_field in data.og_field) - session.add_all(gems + wms + rmis + ccs) + session.add_all(og_fields) await session.flush() return SourceData( - gem={g.id: g.as_entity() for g in gems}, - wm={wm.id: wm.as_entity() for wm in wms}, - rmi={rmi.id: rmi.as_entity() for rmi in rmis}, - cc={cc.id: cc.as_entity() for cc in ccs}, + og_field={g.id: g.as_entity() for g in og_fields}, ) diff --git a/deployments/api/src/stitch/api/entities.py b/deployments/api/src/stitch/api/entities.py index 337a5ee..f15b418 100644 --- a/deployments/api/src/stitch/api/entities.py +++ b/deployments/api/src/stitch/api/entities.py @@ -1,16 +1,17 @@ from collections.abc import Sequence from datetime import datetime from typing import ( - Annotated, Generic, - Literal, Mapping, Protocol, TypeVar, runtime_checkable, ) from uuid import UUID -from pydantic import BaseModel, ConfigDict, EmailStr, Field +from pydantic import BaseModel, EmailStr, Field + +from stitch.ogsi.model.types import OGSISrcKey +from stitch.ogsi.model.og_field import OilGasFieldBase IdType = int | str | UUID @@ -21,13 +22,7 @@ class HasId(Protocol): def id(self) -> IdType: ... -GEM_SRC = Literal["gem"] -WM_SRC = Literal["wm"] -RMI_SRC = Literal["rmi"] -CC_SRC = Literal["cc"] - -SourceKey = GEM_SRC | WM_SRC | RMI_SRC | CC_SRC - +SourceKey = OGSISrcKey TSourceKey = TypeVar("TSourceKey", bound=SourceKey) @@ -57,69 +52,12 @@ class SourceRef(BaseModel): # When pulling into the internal "sources" table, each will get a new unique id which is what the memberships will reference -class GemData(BaseModel): - name: str - lat: float = Field(ge=-90, le=90) - lon: float = Field(ge=-180, le=180) - country: str - - -class GemSource(Identified, GemData): - model_config = ConfigDict(from_attributes=True) - - -class WMData(BaseModel): - field_name: str - field_country: str - production: float - - -class WMSource(Identified, WMData): - model_config = ConfigDict(from_attributes=True) - - -class RMIManualData(BaseModel): - name_override: str - gwp: float - gor: float = Field(gt=0, lt=1) - country: str - latitude: float = Field(ge=-90, le=90) - longitude: float = Field(ge=-180, le=180) - - -class RMIManualSource(Identified, RMIManualData): - model_config = ConfigDict(from_attributes=True) - - -class CCReservoirsData(BaseModel): - name: str - basin: str - depth: float - geofence: Sequence[tuple[float, float]] - - -class CCReservoirsSource(Identified, CCReservoirsData): - model_config = ConfigDict(from_attributes=True) - - -OGSISourcePayload = Annotated[ - GemSource | WMSource | RMIManualSource | CCReservoirsSource, - Field(discriminator="source"), -] - - class SourceData(BaseModel): - gem: Mapping[IdType, GemSource] = Field(default_factory=dict) - wm: Mapping[IdType, WMSource] = Field(default_factory=dict) - rmi: Mapping[IdType, RMIManualSource] = Field(default_factory=dict) - cc: Mapping[IdType, CCReservoirsSource] = Field(default_factory=dict) + og_field: Mapping[IdType, OilGasFieldBase] = Field(default_factory=dict) class CreateSourceData(BaseModel): - gem: Sequence[GemData] = Field(default_factory=list) - wm: Sequence[WMData] = Field(default_factory=list) - rmi: Sequence[RMIManualData] = Field(default_factory=list) - cc: Sequence[CCReservoirsData] = Field(default_factory=list) + og_field: Sequence[OilGasFieldBase] = Field(default_factory=list) class CreateResourceSourceData(BaseModel): @@ -129,20 +67,9 @@ class CreateResourceSourceData(BaseModel): memberships to the resource. """ - gem: Sequence[GemData | int] = Field(default_factory=list) - wm: Sequence[WMData | int] = Field(default_factory=list) - rmi: Sequence[RMIManualData | int] = Field(default_factory=list) - cc: Sequence[CCReservoirsData | int] = Field(default_factory=list) + og_field: Sequence[OilGasFieldBase | int] = Field(default_factory=list) def get(self, key: SourceKey): - if key == "gem": - return self.gem - elif key == "wm": - return self.wm - elif key == "rmi": - return self.rmi - elif key == "cc": - return self.cc raise ValueError(f"Unknown source key: {key}") diff --git a/deployments/api/src/stitch/api/main.py b/deployments/api/src/stitch/api/main.py index 0cbf561..250e342 100644 --- a/deployments/api/src/stitch/api/main.py +++ b/deployments/api/src/stitch/api/main.py @@ -8,12 +8,14 @@ from .auth import validate_auth_config_at_startup from .settings import get_settings -from .routers.resources import router as resource_router from .routers.health import router as health_router +from .routers.oil_gas_fields import router as og_router +from .routers.resources import router as resource_router base_router = APIRouter(prefix="/api/v1") -base_router.include_router(resource_router) base_router.include_router(health_router) +base_router.include_router(og_router) +base_router.include_router(resource_router) @asynccontextmanager diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py new file mode 100644 index 0000000..bb8496a --- /dev/null +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -0,0 +1,48 @@ +from __future__ import annotations + + +from fastapi import APIRouter + +from stitch.api.auth import CurrentUser +from stitch.api.db import resource_actions, og_field_source_actions +from stitch.api.db.config import UnitOfWorkDep +from stitch.api.db.og_field_source_actions import create_source, attach_to_resource +from stitch.api.entities import CreateResource, Resource + +router = APIRouter(prefix="/oil-gas-fields", tags=["oil_gas_fields"]) + + +@router.post("/", response_model=Resource) +async def create_oil_gas_field( + raw_body: dict[str, object], + uow: UnitOfWorkDep, + user: CurrentUser, +): + session = uow.session + + # 1) create a generic resource + resource = await resource_actions.create( + session=session, user=user, resource=CreateResource(name=raw_body.get("name")) + ) + + # 2) create canonical domain source + src = await create_source( + session=session, + raw_payload=raw_body, + source_system=raw_body.get("source_system"), + ) + + # 3) attach it via membership + await attach_to_resource(session, resource.id, src, user) + + return resource + + +@router.get("/", response_model=list[Resource]) +async def list_oil_gas_fields(uow: UnitOfWorkDep, user: CurrentUser): + return await og_field_source_actions.list_og_resources(session=uow.session) + + +@router.get("/{id}", response_model=Resource) +async def get_oil_gas_field(id: int, uow: UnitOfWorkDep, user: CurrentUser): + return await resource_actions.get(session=uow.session, id=id) diff --git a/deployments/api/tests/conftest.py b/deployments/api/tests/conftest.py index 41ee121..1782c11 100644 --- a/deployments/api/tests/conftest.py +++ b/deployments/api/tests/conftest.py @@ -9,24 +9,13 @@ from stitch.api.db.config import UnitOfWork, get_uow from stitch.api.db.model import ( - CCReservoirsSourceModel, - GemSourceModel, - RMIManualSourceModel, StitchBase, UserModel, - WMSourceModel, ) from stitch.api.auth import get_current_user from stitch.api.entities import User from stitch.api.main import app -from .utils import ( - CC_DEFAULTS, - GEM_DEFAULTS, - RMI_DEFAULTS, - WM_DEFAULTS, -) - @pytest.fixture def anyio_backend() -> str: @@ -184,117 +173,3 @@ def override_get_current_user() -> User: base_url="http://test/api/v1", ) as ac: yield ac - - -@pytest.fixture -async def existing_gem_source( - seeded_integration_session: AsyncSession, -) -> GemSourceModel: - """Pre-create a GEM source in DB, return model with ID.""" - model = GemSourceModel( - name=GEM_DEFAULTS["name"], - lat=GEM_DEFAULTS["lat"], - lon=GEM_DEFAULTS["lon"], - country=GEM_DEFAULTS["country"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_wm_source( - seeded_integration_session: AsyncSession, -) -> WMSourceModel: - """Pre-create a WM source in DB, return model with ID.""" - model = WMSourceModel( - field_name=WM_DEFAULTS["field_name"], - field_country=WM_DEFAULTS["field_country"], - production=WM_DEFAULTS["production"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_rmi_source( - seeded_integration_session: AsyncSession, -) -> RMIManualSourceModel: - """Pre-create an RMI source in DB, return model with ID.""" - model = RMIManualSourceModel( - name_override=RMI_DEFAULTS["name_override"], - gwp=RMI_DEFAULTS["gwp"], - gor=RMI_DEFAULTS["gor"], - country=RMI_DEFAULTS["country"], - latitude=RMI_DEFAULTS["latitude"], - longitude=RMI_DEFAULTS["longitude"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_cc_source( - seeded_integration_session: AsyncSession, -) -> CCReservoirsSourceModel: - """Pre-create a CC source in DB, return model with ID.""" - model = CCReservoirsSourceModel( - name=CC_DEFAULTS["name"], - basin=CC_DEFAULTS["basin"], - depth=CC_DEFAULTS["depth"], - geofence=list(CC_DEFAULTS["geofence"]), - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_sources( - seeded_integration_session: AsyncSession, -) -> dict[str, list[int]]: - """Create 2 of each source type, return dict mapping source key to list of IDs.""" - session = seeded_integration_session - - gems = [ - GemSourceModel(name=f"GEM {i}", lat=45.0 + i, lon=-120.0 + i, country="USA") - for i in range(2) - ] - wms = [ - WMSourceModel( - field_name=f"WM Field {i}", field_country="USA", production=1000.0 * (i + 1) - ) - for i in range(2) - ] - rmis = [ - RMIManualSourceModel( - name_override=f"RMI {i}", - gwp=25.0, - gor=0.5, - country="USA", - latitude=40.0 + i, - longitude=-100.0 + i, - ) - for i in range(2) - ] - ccs = [ - CCReservoirsSourceModel( - name=f"CC Reservoir {i}", - basin="Permian", - depth=3000.0, - geofence=[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)], - ) - for i in range(2) - ] - - session.add_all(gems + wms + rmis + ccs) - await session.flush() - - return { - "gem": [g.id for g in gems], - "wm": [w.id for w in wms], - "rmi": [r.id for r in rmis], - "cc": [c.id for c in ccs], - } diff --git a/deployments/api/tests/db/test_resource_actions.py b/deployments/api/tests/db/test_resource_actions.py index 39538a2..a2a83b1 100644 --- a/deployments/api/tests/db/test_resource_actions.py +++ b/deployments/api/tests/db/test_resource_actions.py @@ -1,52 +1,25 @@ -"""Database integration tests for resource_actions module.""" +"""Database integration tests for domain-agnostic resource_actions.""" import pytest from fastapi import HTTPException -from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from stitch.api.db import resource_actions -from stitch.api.db.model import ( - GemSourceModel, - MembershipModel, - ResourceModel, -) -from stitch.api.entities import CreateSourceData, GemData, User, WMData - -from tests.utils import ( - make_cc_data, - make_create_resource, - make_empty_resource, - make_gem_data, - make_resource_with_existing_ids, - make_resource_with_mixed_sources, - make_resource_with_new_sources, - make_rmi_data, - make_source_data, - make_wm_data, -) - - -class TestGetResourceActionUnit: ... - - -class TestCreateResourceActionUnit: ... - - -class TestCreateSourceDataActionUnit: ... +from stitch.api.db.model import ResourceModel +from stitch.api.entities import User +from tests.utils import make_create_resource, make_empty_resource class TestCreateResourceActionIntegration: """Integration tests for resource_actions.create() with real database.""" @pytest.mark.anyio - async def test_creates_resource_with_no_source_data( + async def test_creates_resource_with_minimal_payload( self, seeded_integration_session: AsyncSession, test_user: User, ): - """Resource with no source data persists correctly.""" - resource_in = make_empty_resource(name="Empty Resource", country="USA") + resource_in = make_empty_resource(name=None) result = await resource_actions.create( session=seeded_integration_session, @@ -55,221 +28,18 @@ async def test_creates_resource_with_no_source_data( ) assert result.id is not None - assert result.name == "Empty Resource" - assert result.country == "USA" + assert result.name is None db_resource = await seeded_integration_session.get(ResourceModel, result.id) assert db_resource is not None - assert db_resource.name == "Empty Resource" - - membership_count = ( - await seeded_integration_session.execute( - select(func.count()).select_from(MembershipModel) - ) - ).scalar() - assert membership_count == 0 - - @pytest.mark.anyio - async def test_creates_resource_with_new_gem_source( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """New GEM source creates resource, source, and membership.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="Test GEM Field", lat=40.0, lon=-100.0).model, - name="With GEM", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - assert result.id is not None - assert result.name == "With GEM" - - gem_sources = ( - ( - await seeded_integration_session.execute( - select(GemSourceModel).where( - GemSourceModel.name == "Test GEM Field" - ) - ) - ) - .scalars() - .all() - ) - assert len(gem_sources) == 1 - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 1 - assert memberships[0].source == "gem" - - @pytest.mark.anyio - async def test_creates_resource_with_new_sources_all_types( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """Resource with all four source types creates correct memberships.""" - source_data = make_source_data( - gem=[make_gem_data(name="All Types GEM").model], - wm=[make_wm_data(field_name="All Types WM").model], - rmi=[make_rmi_data(name_override="All Types RMI").model], - cc=[make_cc_data(name="All Types CC").model], - ) - resource_in = make_create_resource( - name="All Sources Resource", - source_data=source_data, - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - sources = {m.source for m in memberships} - assert sources == {"gem", "wm", "rmi", "cc"} - - @pytest.mark.anyio - async def test_creates_resource_with_existing_gem_id( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Existing source ID creates membership without new source record.""" - resource_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="With Existing GEM", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - assert result.id is not None - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 1 - assert memberships[0].source_pk == existing_gem_source.id - - @pytest.mark.anyio - async def test_creates_resource_with_mixed_new_and_existing( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Mix of new sources and existing IDs creates correct memberships.""" - new_gem = make_gem_data(name="Brand New GEM").model - - resource_in = make_resource_with_mixed_sources( - new_gem=new_gem, - existing_gem_ids=[existing_gem_source.id], - name="Mixed Sources", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - gem_memberships = [m for m in memberships if m.source == "gem"] - assert len(gem_memberships) == 2 - - @pytest.mark.anyio - async def test_creates_resource_with_multiple_sources_same_type( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """Multiple sources of same type creates multiple memberships.""" - gems = [make_gem_data(name=f"Multi GEM {i}").model for i in range(3)] - resource_in = make_resource_with_new_sources(gem=gems, name="Multiple GEMs") - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - assert len(memberships) == 3 - assert all(m.source == "gem" for m in memberships) @pytest.mark.anyio - async def test_nonexistent_source_id_creates_no_membership( + async def test_creates_resource_with_label( self, seeded_integration_session: AsyncSession, test_user: User, ): - """Nonexistent source ID is skipped, no membership created.""" - resource_in = make_resource_with_existing_ids( - gem_ids=[99999], - name="With Bad ID", - ) + resource_in = make_create_resource(name="Test Label") result = await resource_actions.create( session=seeded_integration_session, @@ -278,85 +48,27 @@ async def test_nonexistent_source_id_creates_no_membership( ) assert result.id is not None + assert result.name == "Test Label" - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 0 - - @pytest.mark.anyio - async def test_source_can_be_linked_to_multiple_resources( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Verify many-to-many: same source record can belong to multiple resources.""" - resource1_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="First Resource", - ) - result1 = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource1_in.model, - ) - - resource2_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="Second Resource", - ) - result2 = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource2_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.source == "gem", - MembershipModel.source_pk == existing_gem_source.id, - ) - ) - ) - .scalars() - .all() - ) - - assert len(memberships) == 2 - resource_ids = {m.resource_id for m in memberships} - assert resource_ids == {result1.id, result2.id} + db_resource = await seeded_integration_session.get(ResourceModel, result.id) + assert db_resource is not None + # DB ResourceModel may store `.name`; tolerate either while refactor settles. + assert getattr(db_resource, "name", None) in (None, "Test Label") class TestGetResourceActionIntegration: """Integration tests for resource_actions.get() with real database.""" @pytest.mark.anyio - async def test_get_returns_resource_with_populated_source_data( + async def test_get_returns_resource( self, seeded_integration_session: AsyncSession, test_user: User, ): - """GET returns resource with source_data populated.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="Get Test GEM").model, - name="Get Test", - ) - created = await resource_actions.create( session=seeded_integration_session, user=test_user, - resource=resource_in.model, + resource=make_create_resource(name="Get Test").model, ) result = await resource_actions.get( @@ -364,61 +76,46 @@ async def test_get_returns_resource_with_populated_source_data( id=created.id, ) + assert result.id == created.id assert result.name == "Get Test" - assert len(result.source_data.gem) == 1 @pytest.mark.anyio async def test_get_nonexistent_raises_404( self, seeded_integration_session: AsyncSession, ): - """GET nonexistent resource raises HTTPException with 404.""" with pytest.raises(HTTPException) as exc_info: await resource_actions.get( session=seeded_integration_session, id=99999, ) - assert exc_info.value.status_code == 404 -class TestCreateSourceDataActionIntegration: - """Integration tests for resource_actions.create_source_data().""" +class TestListResourcesActionIntegration: + """Integration tests for resource_actions.get_all() with real database.""" @pytest.mark.anyio - async def test_bulk_creates_sources_returns_source_data_with_ids( + async def test_get_all_returns_sequence( self, seeded_integration_session: AsyncSession, + test_user: User, ): - """Bulk create sources returns SourceData with assigned IDs.""" - source_data = CreateSourceData( - gem=[ - GemData(name="Bulk GEM 1", lat=40.0, lon=-100.0, country="USA"), - GemData(name="Bulk GEM 2", lat=41.0, lon=-101.0, country="CAN"), - ], - wm=[ - WMData(field_name="Bulk WM", field_country="USA", production=5000.0), - ], + # create a couple resources + await resource_actions.create( + session=seeded_integration_session, + user=test_user, + resource=make_create_resource(name="A").model, ) - - result = await resource_actions.create_source_data( + await resource_actions.create( session=seeded_integration_session, - data=source_data, + user=test_user, + resource=make_create_resource(name="B").model, ) - assert len(result.gem) == 2 - assert len(result.wm) == 1 - - for gem in result.gem.values(): - assert gem.id is not None + results = await resource_actions.get_all(session=seeded_integration_session) + assert isinstance(results, (list, tuple)) + assert len(results) >= 2 - db_gems = ( - ( - await seeded_integration_session.execute( - select(GemSourceModel).where(GemSourceModel.name.like("Bulk GEM%")) - ) - ) - .scalars() - .all() - ) - assert len(db_gems) == 2 + labels = {r.name for r in results} + assert {"A", "B"} <= labels diff --git a/deployments/api/tests/routers/test_resources_integration.py b/deployments/api/tests/routers/test_resources_integration.py index 58f2d13..cbcd64f 100644 --- a/deployments/api/tests/routers/test_resources_integration.py +++ b/deployments/api/tests/routers/test_resources_integration.py @@ -5,12 +5,7 @@ from stitch.api.db.model import ResourceModel -from tests.utils import ( - make_empty_resource, - make_gem_data, - make_resource_with_new_sources, - make_wm_data, -) +from tests.utils import make_empty_resource, make_create_resource class TestResourcesIntegration: @@ -27,31 +22,20 @@ async def test_get_nonexistent_returns_404(self, integration_client): @pytest.mark.anyio async def test_create_resource_returns_resource(self, integration_client): """POST /resources/ returns the created resource.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="GEM Integration Field", lat=40.0, lon=-100.0).model, - name="Integration Test Resource", - country="USA", - ) + resource_in = make_create_resource(name="Integration Test Resource") response = await integration_client.post("/resources/", json=resource_in.data) assert response.status_code == 200 data = response.json() assert data["name"] == "Integration Test Resource" - assert data["country"] == "USA" assert "id" in data assert data["id"] > 0 @pytest.mark.anyio async def test_create_and_get_resource(self, integration_client): """POST creates resource, GET retrieves it.""" - resource_in = make_resource_with_new_sources( - wm=make_wm_data( - field_name="WM Roundtrip Field", field_country="CAN", production=5000.0 - ).model, - name="Roundtrip Resource", - country="CAN", - ) + resource_in = make_create_resource(name="Roundtrip Resource") create_response = await integration_client.post( "/resources/", json=resource_in.data @@ -66,20 +50,13 @@ async def test_create_and_get_resource(self, integration_client): data = get_response.json() assert data["id"] == created_id assert data["name"] == "Roundtrip Resource" - assert data["country"] == "CAN" @pytest.mark.anyio async def test_create_persists_to_database( self, integration_client, integration_session_factory ): """POST resource is persisted and queryable directly.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data( - name="GEM Persist Field", lat=25.0, lon=-105.0, country="MEX" - ).model, - name="Persisted Resource", - country="MEX", - ) + resource_in = make_create_resource(name="Persisted Resource") response = await integration_client.post("/resources/", json=resource_in.data) @@ -94,12 +71,11 @@ async def test_create_persists_to_database( assert resource is not None assert resource.name == "Persisted Resource" - assert resource.country == "MEX" @pytest.mark.anyio async def test_create_with_minimal_data(self, integration_client): """POST /resources/ works with only required fields (no source data).""" - resource_in = make_empty_resource(name=None, country=None) + resource_in = make_empty_resource(name=None) response = await integration_client.post("/resources/", json=resource_in.data) @@ -107,4 +83,3 @@ async def test_create_with_minimal_data(self, integration_client): data = response.json() assert data["id"] > 0 assert data["name"] is None - assert data["country"] is None diff --git a/deployments/api/tests/routers/test_resources_unit.py b/deployments/api/tests/routers/test_resources_unit.py index b8841ce..87dd4dc 100644 --- a/deployments/api/tests/routers/test_resources_unit.py +++ b/deployments/api/tests/routers/test_resources_unit.py @@ -1,6 +1,5 @@ """Unit tests for resources router with mocked dependencies.""" -from datetime import datetime, timezone from unittest.mock import AsyncMock, patch import pytest @@ -8,31 +7,22 @@ from starlette.status import HTTP_404_NOT_FOUND from stitch.api.db.config import get_uow -from stitch.api.entities import Resource, SourceData from stitch.api.main import app -from tests.utils import ( - make_gem_data, - make_resource_with_new_sources, - make_wm_data, -) +from tests.utils import make_create_resource + +# Import the response model used by the router (domain-agnostic). +from stitch.api.resources.entities import Resource def make_resource( id: int = 1, - name: str = "Test Resource", - country: str = "USA", + name: str | None = "Test Resource", ) -> Resource: """Factory for creating Resource entities for tests.""" - now = datetime.now(timezone.utc) return Resource( id=id, name=name, - country=country, - source_data=SourceData(), - constituents=[], - created=now, - updated=now, ) @@ -88,14 +78,8 @@ class TestCreateResourceUnit: @pytest.mark.anyio async def test_creates_resource_with_user(self, async_client, mock_uow, test_user): """POST /resources/ calls repo.create with user and data.""" - expected = make_resource(id=123, name="New Resource", country="CAN") - resource_in = make_resource_with_new_sources( - gem=make_gem_data( - name="GEM Field", lat=45.0, lon=-120.0, country="CAN" - ).model, - name="New Resource", - country="CAN", - ) + expected = make_resource(id=123, name="New Resource") + resource_in = make_create_resource(name="New Resource") async def override_get_uow(): yield mock_uow @@ -116,12 +100,7 @@ async def override_get_uow(): async def test_returns_created_resource(self, async_client, mock_uow): """POST /resources/ returns the created resource entity.""" expected = make_resource(id=456, name="Created Resource") - resource_in = make_resource_with_new_sources( - wm=make_wm_data( - field_name="WM Field", field_country="USA", production=1000.0 - ).model, - name="Created Resource", - ) + resource_in = make_create_resource(name="Created Resource") async def override_get_uow(): yield mock_uow @@ -147,14 +126,6 @@ async def override_get_uow(): app.dependency_overrides[get_uow] = override_get_uow - response = await async_client.post( - "/resources/", - json={ - "name": "Test Resource", - "source_data": { - "gem": [{"invalid_field": "bad"}], - }, - }, - ) + response = await async_client.post("/resources/", json={"label": 123}) assert response.status_code == 422 diff --git a/deployments/api/tests/utils.py b/deployments/api/tests/utils.py index 490383c..90a52d2 100644 --- a/deployments/api/tests/utils.py +++ b/deployments/api/tests/utils.py @@ -6,20 +6,12 @@ from __future__ import annotations -from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Generic, TypeVar from pydantic import BaseModel -from stitch.api.entities import ( - CCReservoirsData, - CreateResource, - CreateResourceSourceData, - GemData, - RMIManualData, - WMData, -) +from stitch.api.resources.entities import CreateResource T = TypeVar("T", bound=BaseModel) @@ -33,241 +25,23 @@ class FactoryResult(Generic[T]): @property def data(self) -> dict[str, Any]: """Return dict representation via model_dump().""" - return self.model.model_dump() - - -# Static defaults for each source type (no id - these are for creation) -GEM_DEFAULTS: dict[str, Any] = { - "name": "Default GEM Field", - "lat": 45.0, - "lon": -120.0, - "country": "USA", -} - -WM_DEFAULTS: dict[str, Any] = { - "field_name": "Default WM Field", - "field_country": "USA", - "production": 1000.0, -} - -RMI_DEFAULTS: dict[str, Any] = { - "name_override": "Default RMI", - "gwp": 25.0, - "gor": 0.5, - "country": "USA", - "latitude": 40.0, - "longitude": -100.0, -} - -CC_DEFAULTS: dict[str, Any] = { - "name": "Default CC Reservoir", - "basin": "Permian", - "depth": 3000.0, - "geofence": [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)], -} - - -def make_gem_data( - name: str = GEM_DEFAULTS["name"], - lat: float = GEM_DEFAULTS["lat"], - lon: float = GEM_DEFAULTS["lon"], - country: str = GEM_DEFAULTS["country"], -) -> FactoryResult[GemData]: - """Create a GemData with both model and dict representations.""" - return FactoryResult(model=GemData(name=name, lat=lat, lon=lon, country=country)) - - -def make_wm_data( - field_name: str = WM_DEFAULTS["field_name"], - field_country: str = WM_DEFAULTS["field_country"], - production: float = WM_DEFAULTS["production"], -) -> FactoryResult[WMData]: - """Create a WMData with both model and dict representations.""" - return FactoryResult( - model=WMData( - field_name=field_name, field_country=field_country, production=production - ) - ) - - -def make_rmi_data( - name_override: str = RMI_DEFAULTS["name_override"], - gwp: float = RMI_DEFAULTS["gwp"], - gor: float = RMI_DEFAULTS["gor"], - country: str = RMI_DEFAULTS["country"], - latitude: float = RMI_DEFAULTS["latitude"], - longitude: float = RMI_DEFAULTS["longitude"], -) -> FactoryResult[RMIManualData]: - """Create an RMIManualData with both model and dict representations.""" - return FactoryResult( - model=RMIManualData( - name_override=name_override, - gwp=gwp, - gor=gor, - country=country, - latitude=latitude, - longitude=longitude, - ) - ) - - -def make_cc_data( - name: str = CC_DEFAULTS["name"], - basin: str = CC_DEFAULTS["basin"], - depth: float = CC_DEFAULTS["depth"], - geofence: Sequence[tuple[float, float]] = CC_DEFAULTS["geofence"], -) -> FactoryResult[CCReservoirsData]: - """Create a CCReservoirsData with both model and dict representations.""" - return FactoryResult( - model=CCReservoirsData( - name=name, basin=basin, depth=depth, geofence=list(geofence) - ) - ) - - -def make_source_data( - gem: Sequence[GemData | int] | None = None, - wm: Sequence[WMData | int] | None = None, - rmi: Sequence[RMIManualData | int] | None = None, - cc: Sequence[CCReservoirsData | int] | None = None, -) -> FactoryResult[CreateResourceSourceData]: - """Create CreateResourceSourceData with both model and dict representations. - - Args: - gem: List of GemData models or existing source IDs - wm: List of WMData models or existing source IDs - rmi: List of RMIManualData models or existing source IDs - cc: List of CCReservoirsData models or existing source IDs - - Returns: - FactoryResult with model and data (dict) attributes - """ - return FactoryResult( - model=CreateResourceSourceData( - gem=list(gem or []), - wm=list(wm or []), - rmi=list(rmi or []), - cc=list(cc or []), - ) - ) + return self.model.model_dump(mode="json") def make_create_resource( - name: str | None = "Test Resource", - country: str | None = "USA", - source_data: CreateResourceSourceData - | FactoryResult[CreateResourceSourceData] - | None = None, + *, + name: str | None = None, ) -> FactoryResult[CreateResource]: - """Create a CreateResource with both model and dict representations. - - Args: - name: Resource name (optional) - country: Country code (optional) - source_data: Either a CreateResourceSourceData model, a FactoryResult, - or None for empty source data - - Returns: - FactoryResult with model and data (dict) attributes - """ - if source_data is None: - sd_model = None - elif isinstance(source_data, FactoryResult): - sd_model = source_data.model - else: - sd_model = source_data - - return FactoryResult( - model=CreateResource(name=name, country=country, source_data=sd_model) - ) + """Create a minimal, domain-agnostic CreateResource payload.""" + return FactoryResult(model=CreateResource(name=name)) # Convenience factory functions for common test scenarios def make_empty_resource( - name: str | None = "Empty Resource", - country: str | None = "USA", + *, + name: str | None = None, ) -> FactoryResult[CreateResource]: - """Create a resource with no source data.""" - return make_create_resource(name=name, country=country, source_data=None) - - -def make_resource_with_new_sources( - gem: GemData | Sequence[GemData] | None = None, - wm: WMData | Sequence[WMData] | None = None, - rmi: RMIManualData | Sequence[RMIManualData] | None = None, - cc: CCReservoirsData | Sequence[CCReservoirsData] | None = None, - name: str | None = "Resource with Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource with new source data only (no existing IDs).""" - - def to_list(item: Any | Sequence[Any] | None) -> list[Any]: - if item is None: - return [] - if isinstance(item, (list, tuple)): - return list(item) - return [item] - - source_data = make_source_data( - gem=to_list(gem), - wm=to_list(wm), - rmi=to_list(rmi), - cc=to_list(cc), - ) - return make_create_resource(name=name, country=country, source_data=source_data) - - -def make_resource_with_existing_ids( - gem_ids: Sequence[int] | None = None, - wm_ids: Sequence[int] | None = None, - rmi_ids: Sequence[int] | None = None, - cc_ids: Sequence[int] | None = None, - name: str | None = "Resource with Existing Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource referencing existing source IDs only.""" - source_data = make_source_data( - gem=list(gem_ids or []), - wm=list(wm_ids or []), - rmi=list(rmi_ids or []), - cc=list(cc_ids or []), - ) - return make_create_resource(name=name, country=country, source_data=source_data) - - -def make_resource_with_mixed_sources( - new_gem: GemData | Sequence[GemData] | None = None, - existing_gem_ids: Sequence[int] | None = None, - new_wm: WMData | Sequence[WMData] | None = None, - existing_wm_ids: Sequence[int] | None = None, - new_rmi: RMIManualData | Sequence[RMIManualData] | None = None, - existing_rmi_ids: Sequence[int] | None = None, - new_cc: CCReservoirsData | Sequence[CCReservoirsData] | None = None, - existing_cc_ids: Sequence[int] | None = None, - name: str | None = "Resource with Mixed Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource with a mix of new source data and existing source IDs.""" - - def to_list(item: Any | Sequence[Any] | None) -> list[Any]: - if item is None: - return [] - if isinstance(item, (list, tuple)): - return list(item) - return [item] - - gem_items: list[GemData | int] = to_list(new_gem) + list(existing_gem_ids or []) - wm_items: list[WMData | int] = to_list(new_wm) + list(existing_wm_ids or []) - rmi_items: list[RMIManualData | int] = to_list(new_rmi) + list( - existing_rmi_ids or [] - ) - cc_items: list[CCReservoirsData | int] = to_list(new_cc) + list( - existing_cc_ids or [] - ) - - source_data = make_source_data( - gem=gem_items, wm=wm_items, rmi=rmi_items, cc=cc_items - ) - return make_create_resource(name=name, country=country, source_data=source_data) + """Alias for make_create_resource() kept for readability.""" + return make_create_resource(name=name) diff --git a/deployments/stitch-frontend/src/App.jsx b/deployments/stitch-frontend/src/App.jsx index 0568759..098722d 100644 --- a/deployments/stitch-frontend/src/App.jsx +++ b/deployments/stitch-frontend/src/App.jsx @@ -1,5 +1,5 @@ -import ResourcesView from "./components/ResourcesView"; -import ResourceView from "./components/ResourceView"; +import OGFieldsView from "./components/OGFieldsView"; +import OGFieldView from "./components/OGFieldView"; import { LogoutButton } from "./components/LogoutButton"; function App() { @@ -8,8 +8,8 @@ function App() {
- - + + ); } diff --git a/deployments/stitch-frontend/src/App.test.jsx b/deployments/stitch-frontend/src/App.test.jsx index 19496cd..4cf535e 100644 --- a/deployments/stitch-frontend/src/App.test.jsx +++ b/deployments/stitch-frontend/src/App.test.jsx @@ -4,25 +4,25 @@ import { renderWithQueryClient } from "./test/utils"; import App from "./App"; describe("App", () => { - it("renders Resources heading", () => { + it("renders OGFields heading", () => { renderWithQueryClient(); - const heading = screen.getByText(/^Resources$/i); + const heading = screen.getByText(/^OGFields$/i); expect(heading).toBeInTheDocument(); }); - it("renders Resource heading", () => { + it("renders OGField heading", () => { renderWithQueryClient(); - const heading = screen.getByText(/^Resource ID: \d+$/i); + const heading = screen.getByText(/^OGField ID: \d+$/i); expect(heading).toBeInTheDocument(); }); - it("renders both ResourcesView and ResourceView components", () => { + it("renders both OGFieldsView and OGFieldView components", () => { renderWithQueryClient(); - // Check for ResourcesView content - expect(screen.getByText(/^Resources$/i)).toBeInTheDocument(); + // Check for OGFieldsView content + expect(screen.getByText(/^OGFields$/i)).toBeInTheDocument(); - // Check for ResourceView content - expect(screen.getByText(/^Resource ID: \d+$/i)).toBeInTheDocument(); + // Check for OGFieldView content + expect(screen.getByText(/^OGField ID: \d+$/i)).toBeInTheDocument(); }); }); diff --git a/deployments/stitch-frontend/src/components/OGFieldView.jsx b/deployments/stitch-frontend/src/components/OGFieldView.jsx new file mode 100644 index 0000000..a722cf2 --- /dev/null +++ b/deployments/stitch-frontend/src/components/OGFieldView.jsx @@ -0,0 +1,62 @@ +import { useState } from "react"; +import { useQueryClient } from "@tanstack/react-query"; +import { useOGField } from "../hooks/useOGFields"; +import FetchButton from "./FetchButton"; +import ClearCacheButton from "./ClearCacheButton"; +import JsonView from "./JsonView"; +import Input from "./Input"; +import { ogfieldKeys } from "../queries/ogfields"; +import config from "../config/env"; + +export default function OGFieldView({ className, endpoint }) { + const queryClient = useQueryClient(); + const [id, setId] = useState(1); + const { data, isLoading, isError, error, refetch } = useOGField(id); + + const handleClear = (id) => { + queryClient.resetQueries({ queryKey: ogfieldKeys.detail(id) }); + }; + + const handleKeyDown = (e) => { + if (e.key === "Enter") { + refetch(); + } + }; + + return ( +
+

+ OGField ID: {id} +

+
+ + {config.apiBaseUrl} + {endpoint} + +
+
+ setId(Number(e.target.value))} + onKeyDown={handleKeyDown} + min={1} + max={1000} + className="w-24" + /> + refetch()} isLoading={isLoading} /> + handleClear(id)} + disabled={!data && !error} + /> +
+ +
+ ); +} diff --git a/deployments/stitch-frontend/src/components/OGFieldsList.jsx b/deployments/stitch-frontend/src/components/OGFieldsList.jsx new file mode 100644 index 0000000..00e2836 --- /dev/null +++ b/deployments/stitch-frontend/src/components/OGFieldsList.jsx @@ -0,0 +1,50 @@ +import Card from "./Card"; + +function OGFieldsList({ ogfields, isLoading, isError, error }) { + if (isError) { + return ( + +

{error.message}

+
+ ); + } + + if (isLoading) { + return ( + +

Loading...

+
+ ); + } + + if (ogfields?.length > 0) { + return ( + +
    + {ogfields.map((ogfield, index) => ( + + {ogfield.id} + {index < ogfields.length - 1 ? ", " : ""} + + ))} +
+
+
{JSON.stringify(ogfields, null, 2)}
+
+ ); + } + + if (!isLoading && !ogfields?.length) { + return ( + +

+ No ogfields loaded. Click the button above to fetch ogfields. +

+
+ ); + } + + return null; +} + +export default OGFieldsList; diff --git a/deployments/stitch-frontend/src/components/OGFieldsView.jsx b/deployments/stitch-frontend/src/components/OGFieldsView.jsx new file mode 100644 index 0000000..6df6afc --- /dev/null +++ b/deployments/stitch-frontend/src/components/OGFieldsView.jsx @@ -0,0 +1,37 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { useOGFields } from "../hooks/useOGFields"; +import FetchButton from "./FetchButton"; +import ClearCacheButton from "./ClearCacheButton"; +import OGFieldsList from "./OGFieldsList"; +import { ogfieldKeys } from "../queries/ogfields"; + +export default function OGFieldsView({ className, endpoint }) { + const queryClient = useQueryClient(); + const { data, isLoading, isError, error, refetch } = useOGFields(); + + const handleClear = () => { + queryClient.setQueryData(ogfieldKeys.lists(), []); + }; + + return ( +
+

OGFields

+
+ {endpoint} +
+
+ refetch()} isLoading={isLoading} /> + +
+ +
+ ); +} diff --git a/deployments/stitch-frontend/src/hooks/useOGFields.js b/deployments/stitch-frontend/src/hooks/useOGFields.js new file mode 100644 index 0000000..2390e45 --- /dev/null +++ b/deployments/stitch-frontend/src/hooks/useOGFields.js @@ -0,0 +1,10 @@ +import { useAuthenticatedQuery } from "./useAuthenticatedQuery"; +import { ogfieldQueries } from "../queries/ogfields"; + +export function useOGFields() { + return useAuthenticatedQuery(ogfieldQueries.list()); +} + +export function useOGField(id) { + return useAuthenticatedQuery(ogfieldQueries.detail(id)); +} diff --git a/deployments/stitch-frontend/src/queries/api.js b/deployments/stitch-frontend/src/queries/api.js index 91f54d9..459a1aa 100644 --- a/deployments/stitch-frontend/src/queries/api.js +++ b/deployments/stitch-frontend/src/queries/api.js @@ -21,3 +21,25 @@ export async function getResource(id, fetcher) { const data = await response.json(); return data; } + +export async function getOGFields(fetcher) { + const url = `${config.apiBaseUrl}/oil-gas-fields/`; + const response = await fetcher(url); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + const data = await response.json(); + return data; +} + +export async function getOGField(id, fetcher) { + const url = `${config.apiBaseUrl}/oil-gas-fields/${id}`; + const response = await fetcher(url); + if (!response.ok) { + const error = new Error(`HTTP error! status: ${response.status}`); + error.status = response.status; + throw error; + } + const data = await response.json(); + return data; +} diff --git a/deployments/stitch-frontend/src/queries/ogfields.js b/deployments/stitch-frontend/src/queries/ogfields.js new file mode 100644 index 0000000..7574143 --- /dev/null +++ b/deployments/stitch-frontend/src/queries/ogfields.js @@ -0,0 +1,25 @@ +import { getOGField, getOGFields } from "./api"; + +// Query key factory - hierarchical for easy invalidation +export const ogfieldKeys = { + all: ["ogfields"], + lists: () => [...ogfieldKeys.all, "list"], + list: (filters) => [...ogfieldKeys.lists(), filters], + details: () => [...ogfieldKeys.all, "detail"], + detail: (id) => [...ogfieldKeys.details(), id], +}; + +// Query definitions +export const ogfieldQueries = { + list: () => ({ + queryKey: ogfieldKeys.lists(), + queryFn: (fetcher) => getOGFields(fetcher), + enabled: false, + }), + + detail: (id) => ({ + queryKey: ogfieldKeys.detail(id), + queryFn: (fetcher) => getOGField(id, fetcher), + enabled: false, + }), +}; diff --git a/packages/stitch-models/src/stitch/models/__init__.py b/packages/stitch-models/src/stitch/models/__init__.py index 5c1ff3e..6b07ead 100644 --- a/packages/stitch-models/src/stitch/models/__init__.py +++ b/packages/stitch-models/src/stitch/models/__init__.py @@ -13,6 +13,8 @@ "Source", "SourcePayload", "SourceRefTuple", + "EmptySourcePayload", + "BaseResource", ] @@ -64,3 +66,12 @@ def _no_self_reference(self) -> Self: if self.repointed_to == self.id: raise ValueError("A resource cannot be repointed to itself") return self + + +class EmptySourcePayload(SourcePayload): + """Domain-agnostic source payload container (no sources).""" + + +# A concrete, domain-agnostic Resource you can use everywhere. +# Domains can replace `EmptySourcePayload` with their own payload type. +type BaseResource[TResId: IdType] = Resource[TResId, EmptySourcePayload] diff --git a/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py b/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py index d991d75..88ae876 100644 --- a/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py +++ b/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py @@ -61,10 +61,7 @@ class LLMSource(Source[int, LLMSrcKey], OilGasFieldBase): class OGSourcePayload(SourcePayload): - gem: Sequence[GemSource] = Field(default_factory=lambda: []) - wm: Sequence[WoodMacSource] = Field(default_factory=lambda: []) - rmi: Sequence[RMISource] = Field(default_factory=lambda: []) - llm: Sequence[LLMSource] = Field(default_factory=lambda: []) + og_field: Sequence[OGFieldSource] = Field(default_factory=lambda: []) class OGFieldResource(OilGasFieldBase, Resource[int, OGSourcePayload]): ... diff --git a/uv.lock b/uv.lock index 3a6cadc..78643d8 100644 --- a/uv.lock +++ b/uv.lock @@ -1126,6 +1126,8 @@ dependencies = [ { name = "pydantic-settings" }, { name = "sqlalchemy" }, { name = "stitch-auth" }, + { name = "stitch-models" }, + { name = "stitch-ogsi" }, ] [package.dev-dependencies] @@ -1144,6 +1146,8 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, + { name = "stitch-models", editable = "packages/stitch-models" }, + { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, ] [package.metadata.requires-dev]