From 2fa6e4598ec4ae109903222e5c7afd7459d1f977 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Tue, 3 Mar 2026 14:02:11 +0100 Subject: [PATCH 01/25] Add dependency on stitch models for API --- deployments/api/pyproject.toml | 2 ++ deployments/api/src/stitch/api/entities.py | 13 +----------- .../api/src/stitch/api/resources/entities.py | 20 +++++++++++++++++++ .../src/stitch/models/__init__.py | 11 ++++++++++ uv.lock | 2 ++ 5 files changed, 36 insertions(+), 12 deletions(-) create mode 100644 deployments/api/src/stitch/api/resources/entities.py diff --git a/deployments/api/pyproject.toml b/deployments/api/pyproject.toml index 3aed224..869825b 100644 --- a/deployments/api/pyproject.toml +++ b/deployments/api/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "pydantic-settings>=2.12.0", "sqlalchemy>=2.0.44", "stitch-auth", + "stitch-models", ] [project.scripts] @@ -41,3 +42,4 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } +stitch-models = { workspace = true } diff --git a/deployments/api/src/stitch/api/entities.py b/deployments/api/src/stitch/api/entities.py index 337a5ee..d1cbbd1 100644 --- a/deployments/api/src/stitch/api/entities.py +++ b/deployments/api/src/stitch/api/entities.py @@ -5,21 +5,10 @@ Generic, Literal, Mapping, - Protocol, TypeVar, - runtime_checkable, ) -from uuid import UUID from pydantic import BaseModel, ConfigDict, EmailStr, Field - -IdType = int | str | UUID - - -@runtime_checkable -class HasId(Protocol): - @property - def id(self) -> IdType: ... - +from stitch.models.types import IdType GEM_SRC = Literal["gem"] WM_SRC = Literal["wm"] diff --git a/deployments/api/src/stitch/api/resources/entities.py b/deployments/api/src/stitch/api/resources/entities.py new file mode 100644 index 0000000..3af1139 --- /dev/null +++ b/deployments/api/src/stitch/api/resources/entities.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field +from stitch.models import BaseResource, EmptySourcePayload + + +class CreateResource(BaseModel): + """Domain-agnostic create model.""" + + # optional generic metadata; domains can replace/extend later + label: str | None = None + + +class Resource(BaseModel): + """Domain-agnostic read model.""" + + resource: BaseResource[int] = Field( + default_factory=lambda: BaseResource[int](source_data=EmptySourcePayload()) + ) + label: str | None = None diff --git a/packages/stitch-models/src/stitch/models/__init__.py b/packages/stitch-models/src/stitch/models/__init__.py index 5c1ff3e..6b07ead 100644 --- a/packages/stitch-models/src/stitch/models/__init__.py +++ b/packages/stitch-models/src/stitch/models/__init__.py @@ -13,6 +13,8 @@ "Source", "SourcePayload", "SourceRefTuple", + "EmptySourcePayload", + "BaseResource", ] @@ -64,3 +66,12 @@ def _no_self_reference(self) -> Self: if self.repointed_to == self.id: raise ValueError("A resource cannot be repointed to itself") return self + + +class EmptySourcePayload(SourcePayload): + """Domain-agnostic source payload container (no sources).""" + + +# A concrete, domain-agnostic Resource you can use everywhere. +# Domains can replace `EmptySourcePayload` with their own payload type. +type BaseResource[TResId: IdType] = Resource[TResId, EmptySourcePayload] diff --git a/uv.lock b/uv.lock index 05458cc..83a5486 100644 --- a/uv.lock +++ b/uv.lock @@ -1125,6 +1125,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "sqlalchemy" }, { name = "stitch-auth" }, + { name = "stitch-models" }, ] [package.dev-dependencies] @@ -1143,6 +1144,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, + { name = "stitch-models", editable = "packages/stitch-models" }, ] [package.metadata.requires-dev] From 2d08ac87e3ee6fafb554394a54906e86f41b7139 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 14:42:12 +0100 Subject: [PATCH 02/25] remove previous source model in favor of generic --- .../api/src/stitch/api/db/model/__init__.py | 10 -- .../api/src/stitch/api/db/model/resource.py | 35 +--- .../api/src/stitch/api/db/model/sources.py | 141 ---------------- .../api/src/stitch/api/db/resource_actions.py | 121 ++------------ deployments/api/src/stitch/api/entities.py | 155 +----------------- .../api/src/stitch/api/resources/entities.py | 12 +- .../api/src/stitch/api/routers/resources.py | 2 +- 7 files changed, 25 insertions(+), 451 deletions(-) delete mode 100644 deployments/api/src/stitch/api/db/model/sources.py diff --git a/deployments/api/src/stitch/api/db/model/__init__.py b/deployments/api/src/stitch/api/db/model/__init__.py index d3f74ab..9a20770 100644 --- a/deployments/api/src/stitch/api/db/model/__init__.py +++ b/deployments/api/src/stitch/api/db/model/__init__.py @@ -1,21 +1,11 @@ from .common import Base as StitchBase -from .sources import ( - GemSourceModel, - RMIManualSourceModel, - CCReservoirsSourceModel, - WMSourceModel, -) from .resource import MembershipStatus, MembershipModel, ResourceModel from .user import User as UserModel __all__ = [ - "CCReservoirsSourceModel", - "GemSourceModel", "MembershipModel", "MembershipStatus", - "RMIManualSourceModel", "ResourceModel", "StitchBase", "UserModel", - "WMSourceModel", ] diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index 5493faa..48f0a97 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -1,4 +1,3 @@ -from collections import defaultdict from enum import StrEnum from sqlalchemy import ( ForeignKey, @@ -12,13 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship -from .sources import ( - SOURCE_TABLES, - SourceKey, - SourceModel, - SourceModelData, -) -from stitch.api.entities import IdType, User as UserEntity +from stitch.models.types import IdType +from stitch.api.entities import User as UserEntity from .common import Base from .mixins import TimestampMixin, UserAuditMixin from .types import PORTABLE_BIGINT @@ -43,9 +37,7 @@ class MembershipModel(TimestampMixin, UserAuditMixin, Base): ) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) resource_id: Mapped[int] = mapped_column(ForeignKey("resources.id"), nullable=False) - source: Mapped[SourceKey] = mapped_column( - String(10), nullable=False - ) # "gem" | "wm" + source: Mapped[str] = mapped_column(String(10), nullable=False) # "gem" | "wm" source_pk: Mapped[int] = mapped_column(PORTABLE_BIGINT, nullable=False) status: Mapped[MembershipStatus] @@ -54,14 +46,14 @@ def create( cls, created_by: UserEntity, resource: "ResourceModel", - source: SourceKey, + source: str, source_pk: IdType, status: MembershipStatus = MembershipStatus.ACTIVE, ): model = cls( resource_id=resource.id, source=source, - source_pk=str(source_pk), + source_pk=int(source_pk), status=status, created_by_id=created_by.id, last_updated_by_id=created_by.id, @@ -98,23 +90,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base): # and configure the appropriate SQL statement to load the membership objects memberships: Mapped[list[MembershipModel]] = relationship() - async def get_source_data(self, session: AsyncSession): - pks_by_src: dict[SourceKey, set[int]] = defaultdict(set) - for mem in self.memberships: - if mem.status == MembershipStatus.ACTIVE: - pks_by_src[mem.source].add(mem.source_pk) - - results: dict[SourceKey, dict[IdType, SourceModel]] = defaultdict(dict) - for src, pks in pks_by_src.items(): - model_cls = SOURCE_TABLES.get(src) - if model_cls is None: - continue - stmt = select(model_cls).where(model_cls.id.in_(pks)) - for src_model in await session.scalars(stmt): - results[src][src_model.id] = src_model - - return SourceModelData(**results) - async def get_root(self, session: AsyncSession): root = await session.scalar(self.__class__._root_select(self.id)) if root is None: diff --git a/deployments/api/src/stitch/api/db/model/sources.py b/deployments/api/src/stitch/api/db/model/sources.py deleted file mode 100644 index 6b8e5af..0000000 --- a/deployments/api/src/stitch/api/db/model/sources.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing_extensions import Self - -from collections.abc import Mapping, MutableMapping -from typing import Final, Generic, TypeVar, TypedDict, get_args, get_origin -from pydantic import BaseModel -from sqlalchemy import CheckConstraint, inspect -from sqlalchemy.orm import Mapped, mapped_column -from .common import Base -from .types import PORTABLE_BIGINT, StitchJson -from stitch.api.entities import ( - CCReservoirsSource, - GemSource, - IdType, - RMIManualSource, - SourceKey, - WMData, - GemData, - RMIManualData, - CCReservoirsData, - WMSource, -) - - -def float_constraint( - colname: str, min_: float | None = None, max_: float | None = None -) -> CheckConstraint: - min_str = f"{colname} >= {min_}" if min_ is not None else None - max_str = f"{colname} <= {max_}" if max_ is not None else None - expr = " AND ".join(filter(None, (min_str, max_str))) - return CheckConstraint(expr) - - -def lat_constraints(colname: str): - return float_constraint(colname, -90, 90) - - -def lon_constraints(colname: str): - return float_constraint(colname, -180, 180) - - -TModelIn = TypeVar("TModelIn", bound=BaseModel) -TModelOut = TypeVar("TModelOut", bound=BaseModel) - - -class SourceBase(Base, Generic[TModelIn, TModelOut]): - __abstract__ = True - __entity_class_in__: type[TModelIn] - __entity_class_out__: type[TModelOut] - - id: Mapped[int] = mapped_column( - PORTABLE_BIGINT, primary_key=True, autoincrement=True - ) - - def __init_subclass__(cls, **kwargs) -> None: - super().__init_subclass__(**kwargs) - for base in getattr(cls, "__orig_bases__", ()): - if get_origin(base) is SourceBase: - args = get_args(base) - if len(args) >= 2: - if isinstance(args[0], type): - cls.__entity_class_in__ = args[0] - if isinstance(args[1], type): - cls.__entity_class_out__ = args[1] - break - - def as_entity(self): - return self.__entity_class_out__.model_validate(self) - - @classmethod - def from_entity(cls, entity: TModelIn) -> Self: - mapper = inspect(cls) - column_keys = {col.key for col in mapper.columns} - filtered = {k: v for k, v in entity.model_dump().items() if k in column_keys} - return cls(**filtered) - - -class GemSourceModel(SourceBase[GemData, GemSource]): - __tablename__ = "gem_sources" - - name: Mapped[str] - country: Mapped[str] - lat: Mapped[float] = mapped_column(lat_constraints("lat")) - lon: Mapped[float] = mapped_column(lon_constraints("lon")) - - -class WMSourceModel(SourceBase[WMData, WMSource]): - __tablename__ = "wm_sources" - - field_name: Mapped[str] - field_country: Mapped[str] - production: Mapped[float] - - -class RMIManualSourceModel(SourceBase[RMIManualData, RMIManualSource]): - __tablename__ = "rmi_manual_sources" - - name_override: Mapped[str | None] - gwp: Mapped[float | None] - gor: Mapped[float | None | None] = mapped_column( - float_constraint("gor", 0, 1), nullable=True - ) - country: Mapped[str | None] - latitude: Mapped[float | None] = mapped_column( - lat_constraints("latitude"), nullable=True - ) - longitude: Mapped[float | None] = mapped_column( - lon_constraints("longitude"), nullable=True - ) - - -class CCReservoirsSourceModel(SourceBase[CCReservoirsData, CCReservoirsSource]): - __tablename__ = "cc_reservoirs_sources" - - name: Mapped[str] - basin: Mapped[str] - depth: Mapped[float] - geofence: Mapped[list[tuple[float, float]]] = mapped_column(StitchJson()) - - -SourceModel = ( - GemSourceModel | WMSourceModel | RMIManualSourceModel | CCReservoirsSourceModel -) -SourceModelCls = type[SourceModel] - -SOURCE_TABLES: Final[Mapping[SourceKey, SourceModelCls]] = { - "gem": GemSourceModel, - "wm": WMSourceModel, - "rmi": RMIManualSourceModel, - "cc": CCReservoirsSourceModel, -} - - -class SourceModelData(TypedDict, total=False): - gem: MutableMapping[IdType, GemSourceModel] - wm: MutableMapping[IdType, WMSourceModel] - cc: MutableMapping[IdType, CCReservoirsSourceModel] - rmi: MutableMapping[IdType, RMIManualSourceModel] - - -def empty_source_model_data(): - return SourceModelData(gem={}, wm={}, cc={}, rmi={}) diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index c0d19bc..1ad9358 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py +++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -1,93 +1,38 @@ import asyncio -from collections import defaultdict -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from functools import partial from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from starlette.status import HTTP_404_NOT_FOUND - -from stitch.api.db.model.sources import SOURCE_TABLES, SourceModel +from stitch.api.resources.entities import CreateResource, Resource from stitch.api.auth import CurrentUser -from stitch.api.entities import ( - CreateResource, - CreateResourceSourceData, - CreateSourceData, - Resource, - SourceData, - SourceKey, -) from .model import ( - CCReservoirsSourceModel, - GemSourceModel, - MembershipModel, - RMIManualSourceModel, ResourceModel, - WMSourceModel, ) -async def get_or_create_source_models( - session: AsyncSession, - data: CreateResourceSourceData, -) -> Mapping[SourceKey, Sequence[SourceModel]]: - result: dict[SourceKey, list[SourceModel]] = defaultdict(list) - for key, model_cls in SOURCE_TABLES.items(): - for item in data.get(key): - if isinstance(item, int): - src_model = await session.get(model_cls, item) - if src_model is None: - continue - result[key].append(src_model) - else: - result[key].append(model_cls.from_entity(item)) - return result - - -def resource_model_to_empty_entity(model: ResourceModel): - return Resource( - id=model.id, - name=model.name, - country=model.country, - source_data=SourceData(), - constituents=[], - created=model.created, - updated=model.updated, - ) - - async def resource_model_to_entity( session: AsyncSession, model: ResourceModel ) -> Resource: - source_model_data = await model.get_source_data(session) - source_data = SourceData.model_validate(source_model_data) + # Domain-agnostic: constituents are just ids, and there is no source_data. constituent_models = await ResourceModel.get_constituents_by_root_id( session, model.id ) - constituents = [ - resource_model_to_empty_entity(cm) - for cm in constituent_models - if cm.id != model.id - ] + constituent_ids = [m.id for m in constituent_models if m.id != model.id] return Resource( id=model.id, name=model.name, - country=model.country, - source_data=source_data, - constituents=constituents, - created=model.created, - updated=model.updated, + repointed_to=model.repointed_id, + constituents=constituent_ids, + created=str(model.created) if getattr(model, "created", None) else None, + updated=str(model.updated) if getattr(model, "updated", None) else None, ) async def get_all(session: AsyncSession) -> Sequence[Resource]: - stmt = ( - select(ResourceModel) - .where(ResourceModel.repointed_id.is_(None)) - .options(selectinload(ResourceModel.memberships)) - ) + stmt = select(ResourceModel).where(ResourceModel.repointed_id.is_(None)) models = (await session.scalars(stmt)).all() fn = partial(resource_model_to_entity, session) return await asyncio.gather(*[fn(m) for m in models]) @@ -99,58 +44,18 @@ async def get(session: AsyncSession, id: int): raise HTTPException( status_code=HTTP_404_NOT_FOUND, detail=f"No Resource with id `{id}` found." ) - await session.refresh(model, ["memberships"]) return await resource_model_to_entity(session, model) async def create(session: AsyncSession, user: CurrentUser, resource: CreateResource): """ - Here we create a resource either from new source data or existing source data. It's also possible - to create an empty resource with no reference to source data. - - - create the resource - - create the sources - - create membership + Domain-agnostic create: + - create the resource row only """ model = ResourceModel.create( - created_by=user, name=resource.name, country=resource.country + created_by=user, + name=resource.name, ) session.add(model) - if resource.source_data: - src_model_groups = await get_or_create_source_models( - session, resource.source_data - ) - for src_key, src_models in src_model_groups.items(): - session.add_all(src_models) - await session.flush() - for src_model in src_models: - session.add( - MembershipModel.create( - created_by=user, - resource=model, - source=src_key, - source_pk=src_model.id, - ) - ) await session.flush() - await session.refresh(model, ["memberships"]) return await resource_model_to_entity(session, model) - - -async def create_source_data(session: AsyncSession, data: CreateSourceData): - """ - For bulk inserting data into source tables. - """ - gems = tuple(GemSourceModel.from_entity(gem) for gem in data.gem) - wms = tuple(WMSourceModel.from_entity(wm) for wm in data.wm) - rmis = tuple(RMIManualSourceModel.from_entity(rmi) for rmi in data.rmi) - ccs = tuple(CCReservoirsSourceModel.from_entity(cc) for cc in data.cc) - - session.add_all(gems + wms + rmis + ccs) - await session.flush() - return SourceData( - gem={g.id: g.as_entity() for g in gems}, - wm={wm.id: wm.as_entity() for wm in wms}, - rmi={rmi.id: rmi.as_entity() for rmi in rmis}, - cc={cc.id: cc.as_entity() for cc in ccs}, - ) diff --git a/deployments/api/src/stitch/api/entities.py b/deployments/api/src/stitch/api/entities.py index d1cbbd1..fecbfab 100644 --- a/deployments/api/src/stitch/api/entities.py +++ b/deployments/api/src/stitch/api/entities.py @@ -1,154 +1,4 @@ -from collections.abc import Sequence -from datetime import datetime -from typing import ( - Annotated, - Generic, - Literal, - Mapping, - TypeVar, -) -from pydantic import BaseModel, ConfigDict, EmailStr, Field -from stitch.models.types import IdType - -GEM_SRC = Literal["gem"] -WM_SRC = Literal["wm"] -RMI_SRC = Literal["rmi"] -CC_SRC = Literal["cc"] - -SourceKey = GEM_SRC | WM_SRC | RMI_SRC | CC_SRC - -TSourceKey = TypeVar("TSourceKey", bound=SourceKey) - - -class Timestamped(BaseModel): - created: datetime = Field(default_factory=datetime.now) - updated: datetime = Field(default_factory=datetime.now) - - -class Identified(BaseModel): - id: IdType - - -class SourceBase(BaseModel, Generic[TSourceKey]): - source: TSourceKey - id: IdType - - -class SourceRef(BaseModel): - source: SourceKey - id: int - - -# The sources will come in and be initially stored in a raw table. -# That raw table will be an append-only table. -# We'll translate that data into one of the below structures, so each source will have a `UUID` or similar that -# references their id in the "raw" table. -# When pulling into the internal "sources" table, each will get a new unique id which is what the memberships will reference - - -class GemData(BaseModel): - name: str - lat: float = Field(ge=-90, le=90) - lon: float = Field(ge=-180, le=180) - country: str - - -class GemSource(Identified, GemData): - model_config = ConfigDict(from_attributes=True) - - -class WMData(BaseModel): - field_name: str - field_country: str - production: float - - -class WMSource(Identified, WMData): - model_config = ConfigDict(from_attributes=True) - - -class RMIManualData(BaseModel): - name_override: str - gwp: float - gor: float = Field(gt=0, lt=1) - country: str - latitude: float = Field(ge=-90, le=90) - longitude: float = Field(ge=-180, le=180) - - -class RMIManualSource(Identified, RMIManualData): - model_config = ConfigDict(from_attributes=True) - - -class CCReservoirsData(BaseModel): - name: str - basin: str - depth: float - geofence: Sequence[tuple[float, float]] - - -class CCReservoirsSource(Identified, CCReservoirsData): - model_config = ConfigDict(from_attributes=True) - - -OGSISourcePayload = Annotated[ - GemSource | WMSource | RMIManualSource | CCReservoirsSource, - Field(discriminator="source"), -] - - -class SourceData(BaseModel): - gem: Mapping[IdType, GemSource] = Field(default_factory=dict) - wm: Mapping[IdType, WMSource] = Field(default_factory=dict) - rmi: Mapping[IdType, RMIManualSource] = Field(default_factory=dict) - cc: Mapping[IdType, CCReservoirsSource] = Field(default_factory=dict) - - -class CreateSourceData(BaseModel): - gem: Sequence[GemData] = Field(default_factory=list) - wm: Sequence[WMData] = Field(default_factory=list) - rmi: Sequence[RMIManualData] = Field(default_factory=list) - cc: Sequence[CCReservoirsData] = Field(default_factory=list) - - -class CreateResourceSourceData(BaseModel): - """Allows for creating source data or referencing existing sources by ID. - - It can be used in isolation to insert source data or used with a new/existing resource to automatically add - memberships to the resource. - """ - - gem: Sequence[GemData | int] = Field(default_factory=list) - wm: Sequence[WMData | int] = Field(default_factory=list) - rmi: Sequence[RMIManualData | int] = Field(default_factory=list) - cc: Sequence[CCReservoirsData | int] = Field(default_factory=list) - - def get(self, key: SourceKey): - if key == "gem": - return self.gem - elif key == "wm": - return self.wm - elif key == "rmi": - return self.rmi - elif key == "cc": - return self.cc - raise ValueError(f"Unknown source key: {key}") - - -class ResourceBase(BaseModel): - name: str | None = Field(default=None) - country: str | None = Field(default=None) - repointed_to: "Resource | None" = Field(default=None) - - -class Resource(ResourceBase, Timestamped): - id: int - source_data: SourceData - constituents: Sequence["Resource"] - - -class CreateResource(ResourceBase): - source_data: CreateResourceSourceData | None +from pydantic import BaseModel, EmailStr, Field class User(BaseModel): @@ -157,6 +7,3 @@ class User(BaseModel): role: str | None = None email: EmailStr name: str - - -class SourceSelectionLogic(BaseModel): ... diff --git a/deployments/api/src/stitch/api/resources/entities.py b/deployments/api/src/stitch/api/resources/entities.py index 3af1139..ba835bb 100644 --- a/deployments/api/src/stitch/api/resources/entities.py +++ b/deployments/api/src/stitch/api/resources/entities.py @@ -1,20 +1,18 @@ from __future__ import annotations from pydantic import BaseModel, Field -from stitch.models import BaseResource, EmptySourcePayload class CreateResource(BaseModel): """Domain-agnostic create model.""" - # optional generic metadata; domains can replace/extend later - label: str | None = None + name: str | None = None class Resource(BaseModel): """Domain-agnostic read model.""" - resource: BaseResource[int] = Field( - default_factory=lambda: BaseResource[int](source_data=EmptySourcePayload()) - ) - label: str | None = None + id: int + name: str | None = None + repointed_to: int | None = None + constituents: frozenset[int] = Field(default_factory=frozenset) diff --git a/deployments/api/src/stitch/api/routers/resources.py b/deployments/api/src/stitch/api/routers/resources.py index 8d63c15..a2e7d94 100644 --- a/deployments/api/src/stitch/api/routers/resources.py +++ b/deployments/api/src/stitch/api/routers/resources.py @@ -5,7 +5,7 @@ from stitch.api.db import resource_actions from stitch.api.db.config import UnitOfWorkDep from stitch.api.auth import CurrentUser -from stitch.api.entities import CreateResource, Resource +from stitch.api.resources.entities import CreateResource, Resource router = APIRouter( From 5ef1c2587c17c4a2857ac480bc57c55d547871ab Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 20:04:24 +0100 Subject: [PATCH 03/25] Remove domain model from init_job seed data --- deployments/api/src/stitch/api/db/init_job.py | 89 +------------------ .../api/src/stitch/api/db/model/mixins.py | 9 +- 2 files changed, 7 insertions(+), 91 deletions(-) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 5a256c2..b9f17f8 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -12,20 +12,12 @@ from sqlalchemy.orm import Session from stitch.api.db.model import ( - CCReservoirsSourceModel, - GemSourceModel, - MembershipModel, - RMIManualSourceModel, ResourceModel, StitchBase, UserModel, - WMSourceModel, ) from stitch.api.entities import ( - GemData, - RMIManualData, User as UserEntity, - WMData, ) """ @@ -273,81 +265,16 @@ def create_dev_user() -> UserModel: ) -def create_seed_sources(): - gem_sources = [ - GemSourceModel.from_entity( - GemData(name="Permian Basin Field", country="USA", lat=31.8, lon=-102.3) - ), - GemSourceModel.from_entity( - GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5) - ), - ] - for i, src in enumerate(gem_sources, start=1): - src.id = i - - wm_sources = [ - WMSourceModel.from_entity( - WMData( - field_name="Eagle Ford Shale", field_country="USA", production=125000.5 - ) - ), - WMSourceModel.from_entity( - WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0) - ), - ] - for i, src in enumerate(wm_sources, start=1): - src.id = i - - rmi_sources = [ - RMIManualSourceModel.from_entity( - RMIManualData( - name_override="Custom Override Name", - gwp=25.5, - gor=0.45, - country="CAN", - latitude=56.7, - longitude=-111.4, - ) - ), - ] - for i, src in enumerate(rmi_sources, start=1): - src.id = i - - # CC Reservoir sources are intentionally omitted from the dev seed profile; - # the CCReservoirsSourceModel table is still created from SQLAlchemy metadata. - cc_sources: list[CCReservoirsSourceModel] = [] - - return gem_sources, wm_sources, rmi_sources, cc_sources - - def create_seed_resources(user: UserEntity) -> list[ResourceModel]: resources = [ - ResourceModel.create(user, name="Multi-Source Asset", country="USA"), - ResourceModel.create(user, name="Single Source Asset", country="GBR"), + ResourceModel.create(user, name="Resource Foo01"), + ResourceModel.create(user, name="Resource Bar01"), ] for i, res in enumerate(resources, start=1): res.id = i return resources -def create_seed_memberships( - user: UserEntity, - resources: list[ResourceModel], - gem_sources: list[GemSourceModel], - wm_sources: list[WMSourceModel], - rmi_sources: list[RMIManualSourceModel], -) -> list[MembershipModel]: - memberships = [ - MembershipModel.create(user, resources[0], "gem", gem_sources[0].id), - MembershipModel.create(user, resources[0], "wm", wm_sources[0].id), - MembershipModel.create(user, resources[0], "rmi", rmi_sources[0].id), - MembershipModel.create(user, resources[1], "gem", gem_sources[1].id), - ] - for i, mem in enumerate(memberships, start=1): - mem.id = i - return memberships - - def reset_sequences(engine, tables: Iterable[str]) -> None: with engine.begin() as conn: for t in tables: @@ -383,27 +310,15 @@ def seed_dev(engine) -> None: name=dev_model.name, ) - gem_sources, wm_sources, rmi_sources, cc_sources = create_seed_sources() - session.add_all(gem_sources + wm_sources + rmi_sources + cc_sources) - resources = create_seed_resources(user_entity) - resources = create_seed_resources(dev_entity) session.add_all(resources) - memberships = create_seed_memberships( - user_entity, resources, gem_sources, wm_sources, rmi_sources - ) - session.add_all(memberships) - session.commit() reset_sequences( engine, tables=[ "users", - "gem_sources", - "wm_sources", - "rmi_manual_sources", "resources", "memberships", ], diff --git a/deployments/api/src/stitch/api/db/model/mixins.py b/deployments/api/src/stitch/api/db/model/mixins.py index b579582..e689933 100644 --- a/deployments/api/src/stitch/api/db/model/mixins.py +++ b/deployments/api/src/stitch/api/db/model/mixins.py @@ -1,11 +1,10 @@ from datetime import datetime from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin -from pydantic import TypeAdapter +from pydantic import BaseModel, TypeAdapter from sqlalchemy import DateTime, ForeignKey, String, func from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column -from stitch.api.entities import SourceBase from .types import StitchJson @@ -30,7 +29,7 @@ class UserAuditMixin: last_updated_by_id: Mapped[int] = mapped_column(ForeignKey("users.id")) -TPayload = TypeVar("TPayload", bound=SourceBase) +TPayload = TypeVar("TPayload", bound=BaseModel) def _extract_payload_type(cls: type) -> type | None: @@ -63,7 +62,9 @@ def payload(self) -> TPayload: @payload.inplace.setter def _payload_setter(self, value: TPayload): - self.source = value.source + # Domain-agnostic: if the payload has a `source` attribute, keep it; + # otherwise set a neutral default. + self.source = getattr(value, "source", "unknown") self._payload_data = value.model_dump(mode="json") @payload.inplace.expression From e68fcd7ad0ce3dac92fc7dc9dd75ff10d7553bf7 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 20:05:46 +0100 Subject: [PATCH 04/25] linting: unused object --- deployments/api/src/stitch/api/db/init_job.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index b9f17f8..06b88f5 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -303,13 +303,6 @@ def seed_dev(engine) -> None: name=user_model.name, ) - dev_entity = UserEntity( - id=dev_model.id, - sub=dev_model.sub, - email=dev_model.email, - name=dev_model.name, - ) - resources = create_seed_resources(user_entity) session.add_all(resources) From 6ec0a17930a3c9e11b09c97036153b2f548b7472 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 20:56:33 +0100 Subject: [PATCH 05/25] update tests --- .../api/src/stitch/api/resources/entities.py | 3 +- deployments/api/tests/conftest.py | 125 ------ .../api/tests/db/test_resource_actions.py | 371 ++---------------- .../routers/test_resources_integration.py | 35 +- .../api/tests/routers/test_resources_unit.py | 47 +-- deployments/api/tests/utils.py | 246 +----------- 6 files changed, 60 insertions(+), 767 deletions(-) diff --git a/deployments/api/src/stitch/api/resources/entities.py b/deployments/api/src/stitch/api/resources/entities.py index ba835bb..e1f1471 100644 --- a/deployments/api/src/stitch/api/resources/entities.py +++ b/deployments/api/src/stitch/api/resources/entities.py @@ -1,11 +1,12 @@ from __future__ import annotations -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class CreateResource(BaseModel): """Domain-agnostic create model.""" + model_config = ConfigDict(extra="forbid") name: str | None = None diff --git a/deployments/api/tests/conftest.py b/deployments/api/tests/conftest.py index 41ee121..1782c11 100644 --- a/deployments/api/tests/conftest.py +++ b/deployments/api/tests/conftest.py @@ -9,24 +9,13 @@ from stitch.api.db.config import UnitOfWork, get_uow from stitch.api.db.model import ( - CCReservoirsSourceModel, - GemSourceModel, - RMIManualSourceModel, StitchBase, UserModel, - WMSourceModel, ) from stitch.api.auth import get_current_user from stitch.api.entities import User from stitch.api.main import app -from .utils import ( - CC_DEFAULTS, - GEM_DEFAULTS, - RMI_DEFAULTS, - WM_DEFAULTS, -) - @pytest.fixture def anyio_backend() -> str: @@ -184,117 +173,3 @@ def override_get_current_user() -> User: base_url="http://test/api/v1", ) as ac: yield ac - - -@pytest.fixture -async def existing_gem_source( - seeded_integration_session: AsyncSession, -) -> GemSourceModel: - """Pre-create a GEM source in DB, return model with ID.""" - model = GemSourceModel( - name=GEM_DEFAULTS["name"], - lat=GEM_DEFAULTS["lat"], - lon=GEM_DEFAULTS["lon"], - country=GEM_DEFAULTS["country"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_wm_source( - seeded_integration_session: AsyncSession, -) -> WMSourceModel: - """Pre-create a WM source in DB, return model with ID.""" - model = WMSourceModel( - field_name=WM_DEFAULTS["field_name"], - field_country=WM_DEFAULTS["field_country"], - production=WM_DEFAULTS["production"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_rmi_source( - seeded_integration_session: AsyncSession, -) -> RMIManualSourceModel: - """Pre-create an RMI source in DB, return model with ID.""" - model = RMIManualSourceModel( - name_override=RMI_DEFAULTS["name_override"], - gwp=RMI_DEFAULTS["gwp"], - gor=RMI_DEFAULTS["gor"], - country=RMI_DEFAULTS["country"], - latitude=RMI_DEFAULTS["latitude"], - longitude=RMI_DEFAULTS["longitude"], - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_cc_source( - seeded_integration_session: AsyncSession, -) -> CCReservoirsSourceModel: - """Pre-create a CC source in DB, return model with ID.""" - model = CCReservoirsSourceModel( - name=CC_DEFAULTS["name"], - basin=CC_DEFAULTS["basin"], - depth=CC_DEFAULTS["depth"], - geofence=list(CC_DEFAULTS["geofence"]), - ) - seeded_integration_session.add(model) - await seeded_integration_session.flush() - return model - - -@pytest.fixture -async def existing_sources( - seeded_integration_session: AsyncSession, -) -> dict[str, list[int]]: - """Create 2 of each source type, return dict mapping source key to list of IDs.""" - session = seeded_integration_session - - gems = [ - GemSourceModel(name=f"GEM {i}", lat=45.0 + i, lon=-120.0 + i, country="USA") - for i in range(2) - ] - wms = [ - WMSourceModel( - field_name=f"WM Field {i}", field_country="USA", production=1000.0 * (i + 1) - ) - for i in range(2) - ] - rmis = [ - RMIManualSourceModel( - name_override=f"RMI {i}", - gwp=25.0, - gor=0.5, - country="USA", - latitude=40.0 + i, - longitude=-100.0 + i, - ) - for i in range(2) - ] - ccs = [ - CCReservoirsSourceModel( - name=f"CC Reservoir {i}", - basin="Permian", - depth=3000.0, - geofence=[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)], - ) - for i in range(2) - ] - - session.add_all(gems + wms + rmis + ccs) - await session.flush() - - return { - "gem": [g.id for g in gems], - "wm": [w.id for w in wms], - "rmi": [r.id for r in rmis], - "cc": [c.id for c in ccs], - } diff --git a/deployments/api/tests/db/test_resource_actions.py b/deployments/api/tests/db/test_resource_actions.py index 39538a2..a2a83b1 100644 --- a/deployments/api/tests/db/test_resource_actions.py +++ b/deployments/api/tests/db/test_resource_actions.py @@ -1,52 +1,25 @@ -"""Database integration tests for resource_actions module.""" +"""Database integration tests for domain-agnostic resource_actions.""" import pytest from fastapi import HTTPException -from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from stitch.api.db import resource_actions -from stitch.api.db.model import ( - GemSourceModel, - MembershipModel, - ResourceModel, -) -from stitch.api.entities import CreateSourceData, GemData, User, WMData - -from tests.utils import ( - make_cc_data, - make_create_resource, - make_empty_resource, - make_gem_data, - make_resource_with_existing_ids, - make_resource_with_mixed_sources, - make_resource_with_new_sources, - make_rmi_data, - make_source_data, - make_wm_data, -) - - -class TestGetResourceActionUnit: ... - - -class TestCreateResourceActionUnit: ... - - -class TestCreateSourceDataActionUnit: ... +from stitch.api.db.model import ResourceModel +from stitch.api.entities import User +from tests.utils import make_create_resource, make_empty_resource class TestCreateResourceActionIntegration: """Integration tests for resource_actions.create() with real database.""" @pytest.mark.anyio - async def test_creates_resource_with_no_source_data( + async def test_creates_resource_with_minimal_payload( self, seeded_integration_session: AsyncSession, test_user: User, ): - """Resource with no source data persists correctly.""" - resource_in = make_empty_resource(name="Empty Resource", country="USA") + resource_in = make_empty_resource(name=None) result = await resource_actions.create( session=seeded_integration_session, @@ -55,221 +28,18 @@ async def test_creates_resource_with_no_source_data( ) assert result.id is not None - assert result.name == "Empty Resource" - assert result.country == "USA" + assert result.name is None db_resource = await seeded_integration_session.get(ResourceModel, result.id) assert db_resource is not None - assert db_resource.name == "Empty Resource" - - membership_count = ( - await seeded_integration_session.execute( - select(func.count()).select_from(MembershipModel) - ) - ).scalar() - assert membership_count == 0 - - @pytest.mark.anyio - async def test_creates_resource_with_new_gem_source( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """New GEM source creates resource, source, and membership.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="Test GEM Field", lat=40.0, lon=-100.0).model, - name="With GEM", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - assert result.id is not None - assert result.name == "With GEM" - - gem_sources = ( - ( - await seeded_integration_session.execute( - select(GemSourceModel).where( - GemSourceModel.name == "Test GEM Field" - ) - ) - ) - .scalars() - .all() - ) - assert len(gem_sources) == 1 - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 1 - assert memberships[0].source == "gem" - - @pytest.mark.anyio - async def test_creates_resource_with_new_sources_all_types( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """Resource with all four source types creates correct memberships.""" - source_data = make_source_data( - gem=[make_gem_data(name="All Types GEM").model], - wm=[make_wm_data(field_name="All Types WM").model], - rmi=[make_rmi_data(name_override="All Types RMI").model], - cc=[make_cc_data(name="All Types CC").model], - ) - resource_in = make_create_resource( - name="All Sources Resource", - source_data=source_data, - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - sources = {m.source for m in memberships} - assert sources == {"gem", "wm", "rmi", "cc"} - - @pytest.mark.anyio - async def test_creates_resource_with_existing_gem_id( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Existing source ID creates membership without new source record.""" - resource_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="With Existing GEM", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - assert result.id is not None - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 1 - assert memberships[0].source_pk == existing_gem_source.id - - @pytest.mark.anyio - async def test_creates_resource_with_mixed_new_and_existing( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Mix of new sources and existing IDs creates correct memberships.""" - new_gem = make_gem_data(name="Brand New GEM").model - - resource_in = make_resource_with_mixed_sources( - new_gem=new_gem, - existing_gem_ids=[existing_gem_source.id], - name="Mixed Sources", - ) - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - gem_memberships = [m for m in memberships if m.source == "gem"] - assert len(gem_memberships) == 2 - - @pytest.mark.anyio - async def test_creates_resource_with_multiple_sources_same_type( - self, - seeded_integration_session: AsyncSession, - test_user: User, - ): - """Multiple sources of same type creates multiple memberships.""" - gems = [make_gem_data(name=f"Multi GEM {i}").model for i in range(3)] - resource_in = make_resource_with_new_sources(gem=gems, name="Multiple GEMs") - - result = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - - assert len(memberships) == 3 - assert all(m.source == "gem" for m in memberships) @pytest.mark.anyio - async def test_nonexistent_source_id_creates_no_membership( + async def test_creates_resource_with_label( self, seeded_integration_session: AsyncSession, test_user: User, ): - """Nonexistent source ID is skipped, no membership created.""" - resource_in = make_resource_with_existing_ids( - gem_ids=[99999], - name="With Bad ID", - ) + resource_in = make_create_resource(name="Test Label") result = await resource_actions.create( session=seeded_integration_session, @@ -278,85 +48,27 @@ async def test_nonexistent_source_id_creates_no_membership( ) assert result.id is not None + assert result.name == "Test Label" - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.resource_id == result.id - ) - ) - ) - .scalars() - .all() - ) - assert len(memberships) == 0 - - @pytest.mark.anyio - async def test_source_can_be_linked_to_multiple_resources( - self, - seeded_integration_session: AsyncSession, - test_user: User, - existing_gem_source: GemSourceModel, - ): - """Verify many-to-many: same source record can belong to multiple resources.""" - resource1_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="First Resource", - ) - result1 = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource1_in.model, - ) - - resource2_in = make_resource_with_existing_ids( - gem_ids=[existing_gem_source.id], - name="Second Resource", - ) - result2 = await resource_actions.create( - session=seeded_integration_session, - user=test_user, - resource=resource2_in.model, - ) - - memberships = ( - ( - await seeded_integration_session.execute( - select(MembershipModel).where( - MembershipModel.source == "gem", - MembershipModel.source_pk == existing_gem_source.id, - ) - ) - ) - .scalars() - .all() - ) - - assert len(memberships) == 2 - resource_ids = {m.resource_id for m in memberships} - assert resource_ids == {result1.id, result2.id} + db_resource = await seeded_integration_session.get(ResourceModel, result.id) + assert db_resource is not None + # DB ResourceModel may store `.name`; tolerate either while refactor settles. + assert getattr(db_resource, "name", None) in (None, "Test Label") class TestGetResourceActionIntegration: """Integration tests for resource_actions.get() with real database.""" @pytest.mark.anyio - async def test_get_returns_resource_with_populated_source_data( + async def test_get_returns_resource( self, seeded_integration_session: AsyncSession, test_user: User, ): - """GET returns resource with source_data populated.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="Get Test GEM").model, - name="Get Test", - ) - created = await resource_actions.create( session=seeded_integration_session, user=test_user, - resource=resource_in.model, + resource=make_create_resource(name="Get Test").model, ) result = await resource_actions.get( @@ -364,61 +76,46 @@ async def test_get_returns_resource_with_populated_source_data( id=created.id, ) + assert result.id == created.id assert result.name == "Get Test" - assert len(result.source_data.gem) == 1 @pytest.mark.anyio async def test_get_nonexistent_raises_404( self, seeded_integration_session: AsyncSession, ): - """GET nonexistent resource raises HTTPException with 404.""" with pytest.raises(HTTPException) as exc_info: await resource_actions.get( session=seeded_integration_session, id=99999, ) - assert exc_info.value.status_code == 404 -class TestCreateSourceDataActionIntegration: - """Integration tests for resource_actions.create_source_data().""" +class TestListResourcesActionIntegration: + """Integration tests for resource_actions.get_all() with real database.""" @pytest.mark.anyio - async def test_bulk_creates_sources_returns_source_data_with_ids( + async def test_get_all_returns_sequence( self, seeded_integration_session: AsyncSession, + test_user: User, ): - """Bulk create sources returns SourceData with assigned IDs.""" - source_data = CreateSourceData( - gem=[ - GemData(name="Bulk GEM 1", lat=40.0, lon=-100.0, country="USA"), - GemData(name="Bulk GEM 2", lat=41.0, lon=-101.0, country="CAN"), - ], - wm=[ - WMData(field_name="Bulk WM", field_country="USA", production=5000.0), - ], + # create a couple resources + await resource_actions.create( + session=seeded_integration_session, + user=test_user, + resource=make_create_resource(name="A").model, ) - - result = await resource_actions.create_source_data( + await resource_actions.create( session=seeded_integration_session, - data=source_data, + user=test_user, + resource=make_create_resource(name="B").model, ) - assert len(result.gem) == 2 - assert len(result.wm) == 1 - - for gem in result.gem.values(): - assert gem.id is not None + results = await resource_actions.get_all(session=seeded_integration_session) + assert isinstance(results, (list, tuple)) + assert len(results) >= 2 - db_gems = ( - ( - await seeded_integration_session.execute( - select(GemSourceModel).where(GemSourceModel.name.like("Bulk GEM%")) - ) - ) - .scalars() - .all() - ) - assert len(db_gems) == 2 + labels = {r.name for r in results} + assert {"A", "B"} <= labels diff --git a/deployments/api/tests/routers/test_resources_integration.py b/deployments/api/tests/routers/test_resources_integration.py index 58f2d13..cbcd64f 100644 --- a/deployments/api/tests/routers/test_resources_integration.py +++ b/deployments/api/tests/routers/test_resources_integration.py @@ -5,12 +5,7 @@ from stitch.api.db.model import ResourceModel -from tests.utils import ( - make_empty_resource, - make_gem_data, - make_resource_with_new_sources, - make_wm_data, -) +from tests.utils import make_empty_resource, make_create_resource class TestResourcesIntegration: @@ -27,31 +22,20 @@ async def test_get_nonexistent_returns_404(self, integration_client): @pytest.mark.anyio async def test_create_resource_returns_resource(self, integration_client): """POST /resources/ returns the created resource.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data(name="GEM Integration Field", lat=40.0, lon=-100.0).model, - name="Integration Test Resource", - country="USA", - ) + resource_in = make_create_resource(name="Integration Test Resource") response = await integration_client.post("/resources/", json=resource_in.data) assert response.status_code == 200 data = response.json() assert data["name"] == "Integration Test Resource" - assert data["country"] == "USA" assert "id" in data assert data["id"] > 0 @pytest.mark.anyio async def test_create_and_get_resource(self, integration_client): """POST creates resource, GET retrieves it.""" - resource_in = make_resource_with_new_sources( - wm=make_wm_data( - field_name="WM Roundtrip Field", field_country="CAN", production=5000.0 - ).model, - name="Roundtrip Resource", - country="CAN", - ) + resource_in = make_create_resource(name="Roundtrip Resource") create_response = await integration_client.post( "/resources/", json=resource_in.data @@ -66,20 +50,13 @@ async def test_create_and_get_resource(self, integration_client): data = get_response.json() assert data["id"] == created_id assert data["name"] == "Roundtrip Resource" - assert data["country"] == "CAN" @pytest.mark.anyio async def test_create_persists_to_database( self, integration_client, integration_session_factory ): """POST resource is persisted and queryable directly.""" - resource_in = make_resource_with_new_sources( - gem=make_gem_data( - name="GEM Persist Field", lat=25.0, lon=-105.0, country="MEX" - ).model, - name="Persisted Resource", - country="MEX", - ) + resource_in = make_create_resource(name="Persisted Resource") response = await integration_client.post("/resources/", json=resource_in.data) @@ -94,12 +71,11 @@ async def test_create_persists_to_database( assert resource is not None assert resource.name == "Persisted Resource" - assert resource.country == "MEX" @pytest.mark.anyio async def test_create_with_minimal_data(self, integration_client): """POST /resources/ works with only required fields (no source data).""" - resource_in = make_empty_resource(name=None, country=None) + resource_in = make_empty_resource(name=None) response = await integration_client.post("/resources/", json=resource_in.data) @@ -107,4 +83,3 @@ async def test_create_with_minimal_data(self, integration_client): data = response.json() assert data["id"] > 0 assert data["name"] is None - assert data["country"] is None diff --git a/deployments/api/tests/routers/test_resources_unit.py b/deployments/api/tests/routers/test_resources_unit.py index b8841ce..87dd4dc 100644 --- a/deployments/api/tests/routers/test_resources_unit.py +++ b/deployments/api/tests/routers/test_resources_unit.py @@ -1,6 +1,5 @@ """Unit tests for resources router with mocked dependencies.""" -from datetime import datetime, timezone from unittest.mock import AsyncMock, patch import pytest @@ -8,31 +7,22 @@ from starlette.status import HTTP_404_NOT_FOUND from stitch.api.db.config import get_uow -from stitch.api.entities import Resource, SourceData from stitch.api.main import app -from tests.utils import ( - make_gem_data, - make_resource_with_new_sources, - make_wm_data, -) +from tests.utils import make_create_resource + +# Import the response model used by the router (domain-agnostic). +from stitch.api.resources.entities import Resource def make_resource( id: int = 1, - name: str = "Test Resource", - country: str = "USA", + name: str | None = "Test Resource", ) -> Resource: """Factory for creating Resource entities for tests.""" - now = datetime.now(timezone.utc) return Resource( id=id, name=name, - country=country, - source_data=SourceData(), - constituents=[], - created=now, - updated=now, ) @@ -88,14 +78,8 @@ class TestCreateResourceUnit: @pytest.mark.anyio async def test_creates_resource_with_user(self, async_client, mock_uow, test_user): """POST /resources/ calls repo.create with user and data.""" - expected = make_resource(id=123, name="New Resource", country="CAN") - resource_in = make_resource_with_new_sources( - gem=make_gem_data( - name="GEM Field", lat=45.0, lon=-120.0, country="CAN" - ).model, - name="New Resource", - country="CAN", - ) + expected = make_resource(id=123, name="New Resource") + resource_in = make_create_resource(name="New Resource") async def override_get_uow(): yield mock_uow @@ -116,12 +100,7 @@ async def override_get_uow(): async def test_returns_created_resource(self, async_client, mock_uow): """POST /resources/ returns the created resource entity.""" expected = make_resource(id=456, name="Created Resource") - resource_in = make_resource_with_new_sources( - wm=make_wm_data( - field_name="WM Field", field_country="USA", production=1000.0 - ).model, - name="Created Resource", - ) + resource_in = make_create_resource(name="Created Resource") async def override_get_uow(): yield mock_uow @@ -147,14 +126,6 @@ async def override_get_uow(): app.dependency_overrides[get_uow] = override_get_uow - response = await async_client.post( - "/resources/", - json={ - "name": "Test Resource", - "source_data": { - "gem": [{"invalid_field": "bad"}], - }, - }, - ) + response = await async_client.post("/resources/", json={"label": 123}) assert response.status_code == 422 diff --git a/deployments/api/tests/utils.py b/deployments/api/tests/utils.py index 490383c..90a52d2 100644 --- a/deployments/api/tests/utils.py +++ b/deployments/api/tests/utils.py @@ -6,20 +6,12 @@ from __future__ import annotations -from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Generic, TypeVar from pydantic import BaseModel -from stitch.api.entities import ( - CCReservoirsData, - CreateResource, - CreateResourceSourceData, - GemData, - RMIManualData, - WMData, -) +from stitch.api.resources.entities import CreateResource T = TypeVar("T", bound=BaseModel) @@ -33,241 +25,23 @@ class FactoryResult(Generic[T]): @property def data(self) -> dict[str, Any]: """Return dict representation via model_dump().""" - return self.model.model_dump() - - -# Static defaults for each source type (no id - these are for creation) -GEM_DEFAULTS: dict[str, Any] = { - "name": "Default GEM Field", - "lat": 45.0, - "lon": -120.0, - "country": "USA", -} - -WM_DEFAULTS: dict[str, Any] = { - "field_name": "Default WM Field", - "field_country": "USA", - "production": 1000.0, -} - -RMI_DEFAULTS: dict[str, Any] = { - "name_override": "Default RMI", - "gwp": 25.0, - "gor": 0.5, - "country": "USA", - "latitude": 40.0, - "longitude": -100.0, -} - -CC_DEFAULTS: dict[str, Any] = { - "name": "Default CC Reservoir", - "basin": "Permian", - "depth": 3000.0, - "geofence": [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)], -} - - -def make_gem_data( - name: str = GEM_DEFAULTS["name"], - lat: float = GEM_DEFAULTS["lat"], - lon: float = GEM_DEFAULTS["lon"], - country: str = GEM_DEFAULTS["country"], -) -> FactoryResult[GemData]: - """Create a GemData with both model and dict representations.""" - return FactoryResult(model=GemData(name=name, lat=lat, lon=lon, country=country)) - - -def make_wm_data( - field_name: str = WM_DEFAULTS["field_name"], - field_country: str = WM_DEFAULTS["field_country"], - production: float = WM_DEFAULTS["production"], -) -> FactoryResult[WMData]: - """Create a WMData with both model and dict representations.""" - return FactoryResult( - model=WMData( - field_name=field_name, field_country=field_country, production=production - ) - ) - - -def make_rmi_data( - name_override: str = RMI_DEFAULTS["name_override"], - gwp: float = RMI_DEFAULTS["gwp"], - gor: float = RMI_DEFAULTS["gor"], - country: str = RMI_DEFAULTS["country"], - latitude: float = RMI_DEFAULTS["latitude"], - longitude: float = RMI_DEFAULTS["longitude"], -) -> FactoryResult[RMIManualData]: - """Create an RMIManualData with both model and dict representations.""" - return FactoryResult( - model=RMIManualData( - name_override=name_override, - gwp=gwp, - gor=gor, - country=country, - latitude=latitude, - longitude=longitude, - ) - ) - - -def make_cc_data( - name: str = CC_DEFAULTS["name"], - basin: str = CC_DEFAULTS["basin"], - depth: float = CC_DEFAULTS["depth"], - geofence: Sequence[tuple[float, float]] = CC_DEFAULTS["geofence"], -) -> FactoryResult[CCReservoirsData]: - """Create a CCReservoirsData with both model and dict representations.""" - return FactoryResult( - model=CCReservoirsData( - name=name, basin=basin, depth=depth, geofence=list(geofence) - ) - ) - - -def make_source_data( - gem: Sequence[GemData | int] | None = None, - wm: Sequence[WMData | int] | None = None, - rmi: Sequence[RMIManualData | int] | None = None, - cc: Sequence[CCReservoirsData | int] | None = None, -) -> FactoryResult[CreateResourceSourceData]: - """Create CreateResourceSourceData with both model and dict representations. - - Args: - gem: List of GemData models or existing source IDs - wm: List of WMData models or existing source IDs - rmi: List of RMIManualData models or existing source IDs - cc: List of CCReservoirsData models or existing source IDs - - Returns: - FactoryResult with model and data (dict) attributes - """ - return FactoryResult( - model=CreateResourceSourceData( - gem=list(gem or []), - wm=list(wm or []), - rmi=list(rmi or []), - cc=list(cc or []), - ) - ) + return self.model.model_dump(mode="json") def make_create_resource( - name: str | None = "Test Resource", - country: str | None = "USA", - source_data: CreateResourceSourceData - | FactoryResult[CreateResourceSourceData] - | None = None, + *, + name: str | None = None, ) -> FactoryResult[CreateResource]: - """Create a CreateResource with both model and dict representations. - - Args: - name: Resource name (optional) - country: Country code (optional) - source_data: Either a CreateResourceSourceData model, a FactoryResult, - or None for empty source data - - Returns: - FactoryResult with model and data (dict) attributes - """ - if source_data is None: - sd_model = None - elif isinstance(source_data, FactoryResult): - sd_model = source_data.model - else: - sd_model = source_data - - return FactoryResult( - model=CreateResource(name=name, country=country, source_data=sd_model) - ) + """Create a minimal, domain-agnostic CreateResource payload.""" + return FactoryResult(model=CreateResource(name=name)) # Convenience factory functions for common test scenarios def make_empty_resource( - name: str | None = "Empty Resource", - country: str | None = "USA", + *, + name: str | None = None, ) -> FactoryResult[CreateResource]: - """Create a resource with no source data.""" - return make_create_resource(name=name, country=country, source_data=None) - - -def make_resource_with_new_sources( - gem: GemData | Sequence[GemData] | None = None, - wm: WMData | Sequence[WMData] | None = None, - rmi: RMIManualData | Sequence[RMIManualData] | None = None, - cc: CCReservoirsData | Sequence[CCReservoirsData] | None = None, - name: str | None = "Resource with Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource with new source data only (no existing IDs).""" - - def to_list(item: Any | Sequence[Any] | None) -> list[Any]: - if item is None: - return [] - if isinstance(item, (list, tuple)): - return list(item) - return [item] - - source_data = make_source_data( - gem=to_list(gem), - wm=to_list(wm), - rmi=to_list(rmi), - cc=to_list(cc), - ) - return make_create_resource(name=name, country=country, source_data=source_data) - - -def make_resource_with_existing_ids( - gem_ids: Sequence[int] | None = None, - wm_ids: Sequence[int] | None = None, - rmi_ids: Sequence[int] | None = None, - cc_ids: Sequence[int] | None = None, - name: str | None = "Resource with Existing Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource referencing existing source IDs only.""" - source_data = make_source_data( - gem=list(gem_ids or []), - wm=list(wm_ids or []), - rmi=list(rmi_ids or []), - cc=list(cc_ids or []), - ) - return make_create_resource(name=name, country=country, source_data=source_data) - - -def make_resource_with_mixed_sources( - new_gem: GemData | Sequence[GemData] | None = None, - existing_gem_ids: Sequence[int] | None = None, - new_wm: WMData | Sequence[WMData] | None = None, - existing_wm_ids: Sequence[int] | None = None, - new_rmi: RMIManualData | Sequence[RMIManualData] | None = None, - existing_rmi_ids: Sequence[int] | None = None, - new_cc: CCReservoirsData | Sequence[CCReservoirsData] | None = None, - existing_cc_ids: Sequence[int] | None = None, - name: str | None = "Resource with Mixed Sources", - country: str | None = "USA", -) -> FactoryResult[CreateResource]: - """Create a resource with a mix of new source data and existing source IDs.""" - - def to_list(item: Any | Sequence[Any] | None) -> list[Any]: - if item is None: - return [] - if isinstance(item, (list, tuple)): - return list(item) - return [item] - - gem_items: list[GemData | int] = to_list(new_gem) + list(existing_gem_ids or []) - wm_items: list[WMData | int] = to_list(new_wm) + list(existing_wm_ids or []) - rmi_items: list[RMIManualData | int] = to_list(new_rmi) + list( - existing_rmi_ids or [] - ) - cc_items: list[CCReservoirsData | int] = to_list(new_cc) + list( - existing_cc_ids or [] - ) - - source_data = make_source_data( - gem=gem_items, wm=wm_items, rmi=rmi_items, cc=cc_items - ) - return make_create_resource(name=name, country=country, source_data=source_data) + """Alias for make_create_resource() kept for readability.""" + return make_create_resource(name=name) From 5a7b844d21135ea9e58e76e408b88068b03e0d8f Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 21:31:46 +0100 Subject: [PATCH 06/25] remove country from data model --- deployments/api/src/stitch/api/db/model/resource.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index 48f0a97..c98fd49 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -84,7 +84,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base): PORTABLE_BIGINT, ForeignKey("resources.id"), nullable=True ) name: Mapped[str | None] = mapped_column(String, nullable=True) - country: Mapped[str | None] = mapped_column(String(3), nullable=True) # SQLAlchemy will automatically see the foreign key `memberships.resource_id` # and configure the appropriate SQL statement to load the membership objects @@ -105,12 +104,10 @@ def create( cls, created_by: UserEntity, name: str | None = None, - country: str | None = None, repointed_to: int | None = None, ): return cls( name=name, - country=country, repointed_id=repointed_to, created_by_id=created_by.id, last_updated_by_id=created_by.id, From 4d6782834940e4fecf8a2c98563aa19c05f2beb5 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 22:03:23 +0100 Subject: [PATCH 07/25] Add OilGasField endpoint --- deployments/api/pyproject.toml | 2 + .../api/src/stitch/api/db/model/__init__.py | 2 + .../src/stitch/api/db/model/oil_gas_field.py | 31 ++++++++ deployments/api/src/stitch/api/main.py | 6 +- .../api/src/stitch/api/ogsi/entities.py | 18 +++++ .../src/stitch/api/routers/oil_gas_fields.py | 74 +++++++++++++++++++ uv.lock | 2 + 7 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 deployments/api/src/stitch/api/db/model/oil_gas_field.py create mode 100644 deployments/api/src/stitch/api/ogsi/entities.py create mode 100644 deployments/api/src/stitch/api/routers/oil_gas_fields.py diff --git a/deployments/api/pyproject.toml b/deployments/api/pyproject.toml index 869825b..58058c3 100644 --- a/deployments/api/pyproject.toml +++ b/deployments/api/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "sqlalchemy>=2.0.44", "stitch-auth", "stitch-models", + "stitch-ogsi", ] [project.scripts] @@ -43,3 +44,4 @@ addopts = ["-v", "--strict-markers", "--tb=short"] [tool.uv.sources] stitch-auth = { workspace = true } stitch-models = { workspace = true } +stitch-ogsi = { workspace = true } diff --git a/deployments/api/src/stitch/api/db/model/__init__.py b/deployments/api/src/stitch/api/db/model/__init__.py index 9a20770..e4bb764 100644 --- a/deployments/api/src/stitch/api/db/model/__init__.py +++ b/deployments/api/src/stitch/api/db/model/__init__.py @@ -1,4 +1,5 @@ from .common import Base as StitchBase +from .oil_gas_field import OilGasFieldModel from .resource import MembershipStatus, MembershipModel, ResourceModel from .user import User as UserModel @@ -8,4 +9,5 @@ "ResourceModel", "StitchBase", "UserModel", + "OilGasFieldModel", ] diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field.py b/deployments/api/src/stitch/api/db/model/oil_gas_field.py new file mode 100644 index 0000000..363ac90 --- /dev/null +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import ClassVar + +from pydantic import ConfigDict +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column + +from stitch.ogsi.model.og_field import OilGasFieldBase + +from .common import Base +from .mixins import PayloadMixin, TimestampMixin, UserAuditMixin + + +class OilGasFieldModel( + TimestampMixin, UserAuditMixin, PayloadMixin[OilGasFieldBase], Base +): + """Domain wrapper for an OG field, 1:1 with a Resource.""" + + __tablename__ = "oil_gas_fields" + + # Use resource_id as both PK and FK: keeps ids consistent across /resources and /oil_gas_fields + resource_id: Mapped[int] = mapped_column( + ForeignKey("resources.id"), primary_key=True + ) + + # Tell PayloadMixin what type to validate/serialize + payload_type: ClassVar[type[OilGasFieldBase]] = OilGasFieldBase + + # Domain-agnostic: payloads don’t have to have `source` + model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True) diff --git a/deployments/api/src/stitch/api/main.py b/deployments/api/src/stitch/api/main.py index 0cbf561..250e342 100644 --- a/deployments/api/src/stitch/api/main.py +++ b/deployments/api/src/stitch/api/main.py @@ -8,12 +8,14 @@ from .auth import validate_auth_config_at_startup from .settings import get_settings -from .routers.resources import router as resource_router from .routers.health import router as health_router +from .routers.oil_gas_fields import router as og_router +from .routers.resources import router as resource_router base_router = APIRouter(prefix="/api/v1") -base_router.include_router(resource_router) base_router.include_router(health_router) +base_router.include_router(og_router) +base_router.include_router(resource_router) @asynccontextmanager diff --git a/deployments/api/src/stitch/api/ogsi/entities.py b/deployments/api/src/stitch/api/ogsi/entities.py new file mode 100644 index 0000000..487336c --- /dev/null +++ b/deployments/api/src/stitch/api/ogsi/entities.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field +from stitch.api.resources.entities import CreateResource, Resource as ResourceView + + +class CreateOilGasField(BaseModel): + model_config = ConfigDict(extra="forbid") + resource: CreateResource + owner: str | None = Field(default=None) + operator: str | None = Field(default=None) + + +class OilGasField(BaseModel): + id: int + resource: ResourceView + owner: str | None = None + operator: str | None = None diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py new file mode 100644 index 0000000..7b2bbcb --- /dev/null +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections.abc import Sequence + +from fastapi import APIRouter, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from starlette.status import HTTP_404_NOT_FOUND + +from stitch.api.auth import CurrentUser +from stitch.api.db import resource_actions +from stitch.api.db.config import UnitOfWorkDep +from stitch.api.db.model import OilGasFieldModel + +from stitch.ogsi.model.og_field import OilGasFieldBase # request model +from stitch.ogsi.model import OGFieldView # response model + +router = APIRouter(prefix="/oil_gas_fields", tags=["oil_gas_fields"]) + + +@router.post("/", response_model=OGFieldView) +async def create_oil_gas_field( + payload: OilGasFieldBase, + uow: UnitOfWorkDep, + user: CurrentUser, +): + session: AsyncSession = uow.session + + # Create the generic resource first (label derived from OG name) + created_res = await resource_actions.create( + session=session, + user=user, + resource=resource_actions.CreateResource( + name=payload.name + ), # adjust import if CreateResource lives elsewhere in your branch + ) + + og = OilGasFieldModel( + resource_id=created_res.id, + created_by_id=user.id, + last_updated_by_id=user.id, + ) + og.payload = payload + session.add(og) + await session.flush() + + # Package response type + return OGFieldView(id=og.resource_id, **payload.model_dump()) + + +@router.get("/", response_model=Sequence[OGFieldView]) +async def list_oil_gas_fields(uow: UnitOfWorkDep): + session: AsyncSession = uow.session + rows = (await session.execute(select(OilGasFieldModel))).scalars().all() + + out: list[OGFieldView] = [] + for row in rows: + p = row.payload + out.append(OGFieldView(id=row.resource_id, **p.model_dump())) + return out + + +@router.get("/{id}", response_model=OGFieldView) +async def get_oil_gas_field(id: int, uow: UnitOfWorkDep): + session: AsyncSession = uow.session + row = await session.get(OilGasFieldModel, id) + if row is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"No OilGasField with id `{id}` found.", + ) + + p = row.payload + return OGFieldView(id=row.resource_id, **p.model_dump()) diff --git a/uv.lock b/uv.lock index ac9e6dd..78643d8 100644 --- a/uv.lock +++ b/uv.lock @@ -1127,6 +1127,7 @@ dependencies = [ { name = "sqlalchemy" }, { name = "stitch-auth" }, { name = "stitch-models" }, + { name = "stitch-ogsi" }, ] [package.dev-dependencies] @@ -1146,6 +1147,7 @@ requires-dist = [ { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "stitch-auth", editable = "packages/stitch-auth" }, { name = "stitch-models", editable = "packages/stitch-models" }, + { name = "stitch-ogsi", editable = "packages/stitch-ogsi" }, ] [package.metadata.requires-dev] From 85fddad38eb3dcbb21b87a6cae0d56edf585c368 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 22:48:38 +0100 Subject: [PATCH 08/25] Get OGfields in frontend --- deployments/stitch-frontend/src/App.jsx | 8 +-- deployments/stitch-frontend/src/App.test.jsx | 18 +++--- .../src/components/OGFieldView.jsx | 62 +++++++++++++++++++ .../src/components/OGFieldsList.jsx | 50 +++++++++++++++ .../src/components/OGFieldsView.jsx | 37 +++++++++++ .../stitch-frontend/src/hooks/useOGFields.js | 10 +++ .../stitch-frontend/src/queries/api.js | 22 +++++++ .../stitch-frontend/src/queries/ogfields.js | 25 ++++++++ 8 files changed, 219 insertions(+), 13 deletions(-) create mode 100644 deployments/stitch-frontend/src/components/OGFieldView.jsx create mode 100644 deployments/stitch-frontend/src/components/OGFieldsList.jsx create mode 100644 deployments/stitch-frontend/src/components/OGFieldsView.jsx create mode 100644 deployments/stitch-frontend/src/hooks/useOGFields.js create mode 100644 deployments/stitch-frontend/src/queries/ogfields.js diff --git a/deployments/stitch-frontend/src/App.jsx b/deployments/stitch-frontend/src/App.jsx index 0568759..c3361a5 100644 --- a/deployments/stitch-frontend/src/App.jsx +++ b/deployments/stitch-frontend/src/App.jsx @@ -1,5 +1,5 @@ -import ResourcesView from "./components/ResourcesView"; -import ResourceView from "./components/ResourceView"; +import OGFieldsView from "./components/OGFieldsView"; +import OGFieldView from "./components/OGFieldView"; import { LogoutButton } from "./components/LogoutButton"; function App() { @@ -8,8 +8,8 @@ function App() {
- - + + ); } diff --git a/deployments/stitch-frontend/src/App.test.jsx b/deployments/stitch-frontend/src/App.test.jsx index 19496cd..4cf535e 100644 --- a/deployments/stitch-frontend/src/App.test.jsx +++ b/deployments/stitch-frontend/src/App.test.jsx @@ -4,25 +4,25 @@ import { renderWithQueryClient } from "./test/utils"; import App from "./App"; describe("App", () => { - it("renders Resources heading", () => { + it("renders OGFields heading", () => { renderWithQueryClient(); - const heading = screen.getByText(/^Resources$/i); + const heading = screen.getByText(/^OGFields$/i); expect(heading).toBeInTheDocument(); }); - it("renders Resource heading", () => { + it("renders OGField heading", () => { renderWithQueryClient(); - const heading = screen.getByText(/^Resource ID: \d+$/i); + const heading = screen.getByText(/^OGField ID: \d+$/i); expect(heading).toBeInTheDocument(); }); - it("renders both ResourcesView and ResourceView components", () => { + it("renders both OGFieldsView and OGFieldView components", () => { renderWithQueryClient(); - // Check for ResourcesView content - expect(screen.getByText(/^Resources$/i)).toBeInTheDocument(); + // Check for OGFieldsView content + expect(screen.getByText(/^OGFields$/i)).toBeInTheDocument(); - // Check for ResourceView content - expect(screen.getByText(/^Resource ID: \d+$/i)).toBeInTheDocument(); + // Check for OGFieldView content + expect(screen.getByText(/^OGField ID: \d+$/i)).toBeInTheDocument(); }); }); diff --git a/deployments/stitch-frontend/src/components/OGFieldView.jsx b/deployments/stitch-frontend/src/components/OGFieldView.jsx new file mode 100644 index 0000000..a722cf2 --- /dev/null +++ b/deployments/stitch-frontend/src/components/OGFieldView.jsx @@ -0,0 +1,62 @@ +import { useState } from "react"; +import { useQueryClient } from "@tanstack/react-query"; +import { useOGField } from "../hooks/useOGFields"; +import FetchButton from "./FetchButton"; +import ClearCacheButton from "./ClearCacheButton"; +import JsonView from "./JsonView"; +import Input from "./Input"; +import { ogfieldKeys } from "../queries/ogfields"; +import config from "../config/env"; + +export default function OGFieldView({ className, endpoint }) { + const queryClient = useQueryClient(); + const [id, setId] = useState(1); + const { data, isLoading, isError, error, refetch } = useOGField(id); + + const handleClear = (id) => { + queryClient.resetQueries({ queryKey: ogfieldKeys.detail(id) }); + }; + + const handleKeyDown = (e) => { + if (e.key === "Enter") { + refetch(); + } + }; + + return ( +
+

+ OGField ID: {id} +

+
+ + {config.apiBaseUrl} + {endpoint} + +
+
+ setId(Number(e.target.value))} + onKeyDown={handleKeyDown} + min={1} + max={1000} + className="w-24" + /> + refetch()} isLoading={isLoading} /> + handleClear(id)} + disabled={!data && !error} + /> +
+ +
+ ); +} diff --git a/deployments/stitch-frontend/src/components/OGFieldsList.jsx b/deployments/stitch-frontend/src/components/OGFieldsList.jsx new file mode 100644 index 0000000..00e2836 --- /dev/null +++ b/deployments/stitch-frontend/src/components/OGFieldsList.jsx @@ -0,0 +1,50 @@ +import Card from "./Card"; + +function OGFieldsList({ ogfields, isLoading, isError, error }) { + if (isError) { + return ( + +

{error.message}

+
+ ); + } + + if (isLoading) { + return ( + +

Loading...

+
+ ); + } + + if (ogfields?.length > 0) { + return ( + +
    + {ogfields.map((ogfield, index) => ( + + {ogfield.id} + {index < ogfields.length - 1 ? ", " : ""} + + ))} +
+
+
{JSON.stringify(ogfields, null, 2)}
+
+ ); + } + + if (!isLoading && !ogfields?.length) { + return ( + +

+ No ogfields loaded. Click the button above to fetch ogfields. +

+
+ ); + } + + return null; +} + +export default OGFieldsList; diff --git a/deployments/stitch-frontend/src/components/OGFieldsView.jsx b/deployments/stitch-frontend/src/components/OGFieldsView.jsx new file mode 100644 index 0000000..6df6afc --- /dev/null +++ b/deployments/stitch-frontend/src/components/OGFieldsView.jsx @@ -0,0 +1,37 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { useOGFields } from "../hooks/useOGFields"; +import FetchButton from "./FetchButton"; +import ClearCacheButton from "./ClearCacheButton"; +import OGFieldsList from "./OGFieldsList"; +import { ogfieldKeys } from "../queries/ogfields"; + +export default function OGFieldsView({ className, endpoint }) { + const queryClient = useQueryClient(); + const { data, isLoading, isError, error, refetch } = useOGFields(); + + const handleClear = () => { + queryClient.setQueryData(ogfieldKeys.lists(), []); + }; + + return ( +
+

OGFields

+
+ {endpoint} +
+
+ refetch()} isLoading={isLoading} /> + +
+ +
+ ); +} diff --git a/deployments/stitch-frontend/src/hooks/useOGFields.js b/deployments/stitch-frontend/src/hooks/useOGFields.js new file mode 100644 index 0000000..2390e45 --- /dev/null +++ b/deployments/stitch-frontend/src/hooks/useOGFields.js @@ -0,0 +1,10 @@ +import { useAuthenticatedQuery } from "./useAuthenticatedQuery"; +import { ogfieldQueries } from "../queries/ogfields"; + +export function useOGFields() { + return useAuthenticatedQuery(ogfieldQueries.list()); +} + +export function useOGField(id) { + return useAuthenticatedQuery(ogfieldQueries.detail(id)); +} diff --git a/deployments/stitch-frontend/src/queries/api.js b/deployments/stitch-frontend/src/queries/api.js index 91f54d9..060ae51 100644 --- a/deployments/stitch-frontend/src/queries/api.js +++ b/deployments/stitch-frontend/src/queries/api.js @@ -21,3 +21,25 @@ export async function getResource(id, fetcher) { const data = await response.json(); return data; } + +export async function getOGFields(fetcher) { + const url = `${config.apiBaseUrl}/oilgasfields/`; + const response = await fetcher(url); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + const data = await response.json(); + return data; +} + +export async function getOGField(id, fetcher) { + const url = `${config.apiBaseUrl}/oilgasfields/${id}`; + const response = await fetcher(url); + if (!response.ok) { + const error = new Error(`HTTP error! status: ${response.status}`); + error.status = response.status; + throw error; + } + const data = await response.json(); + return data; +} diff --git a/deployments/stitch-frontend/src/queries/ogfields.js b/deployments/stitch-frontend/src/queries/ogfields.js new file mode 100644 index 0000000..7574143 --- /dev/null +++ b/deployments/stitch-frontend/src/queries/ogfields.js @@ -0,0 +1,25 @@ +import { getOGField, getOGFields } from "./api"; + +// Query key factory - hierarchical for easy invalidation +export const ogfieldKeys = { + all: ["ogfields"], + lists: () => [...ogfieldKeys.all, "list"], + list: (filters) => [...ogfieldKeys.lists(), filters], + details: () => [...ogfieldKeys.all, "detail"], + detail: (id) => [...ogfieldKeys.details(), id], +}; + +// Query definitions +export const ogfieldQueries = { + list: () => ({ + queryKey: ogfieldKeys.lists(), + queryFn: (fetcher) => getOGFields(fetcher), + enabled: false, + }), + + detail: (id) => ({ + queryKey: ogfieldKeys.detail(id), + queryFn: (fetcher) => getOGField(id, fetcher), + enabled: false, + }), +}; From c47fc77d79f25b2f14f6367e25076f94d869c14c Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 22:49:07 +0100 Subject: [PATCH 09/25] add OGfields in init_job --- deployments/api/src/stitch/api/db/init_job.py | 64 ++++++++++++------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 06b88f5..fe3bc05 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -5,7 +5,6 @@ import time from enum import Enum from dataclasses import dataclass -from typing import Iterable from sqlalchemy import create_engine, inspect, text from sqlalchemy.exc import OperationalError @@ -15,11 +14,15 @@ ResourceModel, StitchBase, UserModel, + OilGasFieldModel, ) from stitch.api.entities import ( User as UserEntity, ) +# Domain model from stitch-ogsi package +from stitch.ogsi.model.og_field import OilGasFieldBase + """ DB init/seed job. @@ -249,7 +252,6 @@ def fail_partial(existing_tables: set[str], expected: set[str]) -> None: def create_seed_user() -> UserModel: return UserModel( - id=1, sub="seed|system", name="Seed User", email="seed@example.com", @@ -258,7 +260,6 @@ def create_seed_user() -> UserModel: def create_dev_user() -> UserModel: return UserModel( - id=2, sub="dev|local-placeholder", name="Dev Deverson", email="dev@example.com", @@ -270,21 +271,41 @@ def create_seed_resources(user: UserEntity) -> list[ResourceModel]: ResourceModel.create(user, name="Resource Foo01"), ResourceModel.create(user, name="Resource Bar01"), ] - for i, res in enumerate(resources, start=1): - res.id = i return resources -def reset_sequences(engine, tables: Iterable[str]) -> None: - with engine.begin() as conn: - for t in tables: - conn.execute( - text( - f"SELECT setval('{t}_id_seq', " - f"(SELECT COALESCE(MAX(id), 0) + 1 FROM {t}), false);" - ) - ) +def create_seed_oil_gas_fields( + user: UserEntity, + resources: list[ResourceModel], +) -> list[OilGasFieldModel]: + """Create example OilGasField rows linked 1:1 with seeded resources.""" + + # Construct payloads using the package model + payloads = [ + OilGasFieldBase( + name="Permian Alpha", + country="USA", + basin="Permian", + ), + OilGasFieldBase( + name="North Sea Bravo", + country="GBR", + basin="North Sea", + ), + ] + og_models: list[OilGasFieldModel] = [] + + for resource, payload in zip(resources, payloads): + model = OilGasFieldModel( + resource_id=resource.id, + created_by_id=user.id, + last_updated_by_id=user.id, + ) + model.payload = payload + og_models.append(model) + + return og_models def seed_dev(engine) -> None: with Session(engine) as session: @@ -305,19 +326,14 @@ def seed_dev(engine) -> None: resources = create_seed_resources(user_entity) session.add_all(resources) + session.flush() + # + # Add sample OilGasField rows for the first two resources only + og_fields = create_seed_oil_gas_fields(user_entity, resources) + session.add_all(og_fields) session.commit() - reset_sequences( - engine, - tables=[ - "users", - "resources", - "memberships", - ], - ) - - def seed(engine, profile: SeedProfile | str) -> None: if profile == "dev": seed_dev(engine) From 7fd4a71fc9a0c645aef8af71c7259f6d0180bbec Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 23:05:40 +0100 Subject: [PATCH 10/25] match path between frontend and api --- deployments/stitch-frontend/src/App.jsx | 4 ++-- deployments/stitch-frontend/src/queries/api.js | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deployments/stitch-frontend/src/App.jsx b/deployments/stitch-frontend/src/App.jsx index c3361a5..22b8d17 100644 --- a/deployments/stitch-frontend/src/App.jsx +++ b/deployments/stitch-frontend/src/App.jsx @@ -8,8 +8,8 @@ function App() {
- - + + ); } diff --git a/deployments/stitch-frontend/src/queries/api.js b/deployments/stitch-frontend/src/queries/api.js index 060ae51..b79285e 100644 --- a/deployments/stitch-frontend/src/queries/api.js +++ b/deployments/stitch-frontend/src/queries/api.js @@ -23,7 +23,7 @@ export async function getResource(id, fetcher) { } export async function getOGFields(fetcher) { - const url = `${config.apiBaseUrl}/oilgasfields/`; + const url = `${config.apiBaseUrl}/oil_gas_fields/`; const response = await fetcher(url); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); @@ -33,7 +33,7 @@ export async function getOGFields(fetcher) { } export async function getOGField(id, fetcher) { - const url = `${config.apiBaseUrl}/oilgasfields/${id}`; + const url = `${config.apiBaseUrl}/oil_gas_fields/${id}`; const response = await fetcher(url); if (!response.ok) { const error = new Error(`HTTP error! status: ${response.status}`); From 5bfbb243a4b6554c180dc836d260e1e989791cc8 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 23:08:46 +0100 Subject: [PATCH 11/25] update path to kebab case --- deployments/api/src/stitch/api/routers/oil_gas_fields.py | 2 +- deployments/stitch-frontend/src/App.jsx | 4 ++-- deployments/stitch-frontend/src/queries/api.js | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py index 7b2bbcb..209ac42 100644 --- a/deployments/api/src/stitch/api/routers/oil_gas_fields.py +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -15,7 +15,7 @@ from stitch.ogsi.model.og_field import OilGasFieldBase # request model from stitch.ogsi.model import OGFieldView # response model -router = APIRouter(prefix="/oil_gas_fields", tags=["oil_gas_fields"]) +router = APIRouter(prefix="/oil-gas-fields", tags=["oil_gas_fields"]) @router.post("/", response_model=OGFieldView) diff --git a/deployments/stitch-frontend/src/App.jsx b/deployments/stitch-frontend/src/App.jsx index 22b8d17..098722d 100644 --- a/deployments/stitch-frontend/src/App.jsx +++ b/deployments/stitch-frontend/src/App.jsx @@ -8,8 +8,8 @@ function App() {
- - + + ); } diff --git a/deployments/stitch-frontend/src/queries/api.js b/deployments/stitch-frontend/src/queries/api.js index b79285e..459a1aa 100644 --- a/deployments/stitch-frontend/src/queries/api.js +++ b/deployments/stitch-frontend/src/queries/api.js @@ -23,7 +23,7 @@ export async function getResource(id, fetcher) { } export async function getOGFields(fetcher) { - const url = `${config.apiBaseUrl}/oil_gas_fields/`; + const url = `${config.apiBaseUrl}/oil-gas-fields/`; const response = await fetcher(url); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); @@ -33,7 +33,7 @@ export async function getOGFields(fetcher) { } export async function getOGField(id, fetcher) { - const url = `${config.apiBaseUrl}/oil_gas_fields/${id}`; + const url = `${config.apiBaseUrl}/oil-gas-fields/${id}`; const response = await fetcher(url); if (!response.ok) { const error = new Error(`HTTP error! status: ${response.status}`); From e904786a723523ce2e2c49e8063496b643dd61f2 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 23:16:43 +0100 Subject: [PATCH 12/25] Trigger CI From 344e99bf323b1ec1e00ff4534971234e0f421eeb Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Wed, 4 Mar 2026 23:18:44 +0100 Subject: [PATCH 13/25] style: ruff --- deployments/api/src/stitch/api/db/init_job.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index fe3bc05..28b66e9 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -307,6 +307,7 @@ def create_seed_oil_gas_fields( return og_models + def seed_dev(engine) -> None: with Session(engine) as session: user_model = create_seed_user() @@ -334,6 +335,7 @@ def seed_dev(engine) -> None: session.commit() + def seed(engine, profile: SeedProfile | str) -> None: if profile == "dev": seed_dev(engine) From da251d8e4b8e4871827b8a91b2b513a33fcf7dc4 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 10:35:23 +0100 Subject: [PATCH 14/25] add domain model to DB --- .../src/stitch/api/db/model/oil_gas_field.py | 66 ++++++++++++++++++- .../src/stitch/api/routers/oil_gas_fields.py | 15 +++-- 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field.py b/deployments/api/src/stitch/api/db/model/oil_gas_field.py index 363ac90..d1e6b06 100644 --- a/deployments/api/src/stitch/api/db/model/oil_gas_field.py +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import ClassVar +from typing import Any, ClassVar from pydantic import ConfigDict -from sqlalchemy import ForeignKey +from sqlalchemy import DateTime, Float, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column from stitch.ogsi.model.og_field import OilGasFieldBase @@ -12,6 +13,11 @@ from .mixins import PayloadMixin, TimestampMixin, UserAuditMixin + +def _enum_value(v: Any) -> Any: + # Works for python Enums (return .value) but leaves strings/None unchanged. + return getattr(v, "value", v) + class OilGasFieldModel( TimestampMixin, UserAuditMixin, PayloadMixin[OilGasFieldBase], Base ): @@ -24,8 +30,64 @@ class OilGasFieldModel( ForeignKey("resources.id"), primary_key=True ) + # --- Domain columns (queryable / indexed later if needed) --- + name: Mapped[str | None] = mapped_column(String, nullable=True) + country: Mapped[str | None] = mapped_column(String(3), nullable=True) + latitude: Mapped[float | None] = mapped_column(Float, nullable=True) + longitude: Mapped[float | None] = mapped_column(Float, nullable=True) + last_updated: Mapped[Any | None] = mapped_column(DateTime(timezone=True), nullable=True) + name_local: Mapped[str | None] = mapped_column(String, nullable=True) + state_province: Mapped[str | None] = mapped_column(String, nullable=True) + region: Mapped[str | None] = mapped_column(String, nullable=True) + basin: Mapped[str | None] = mapped_column(String, nullable=True) + reservoir_formation: Mapped[str | None] = mapped_column(String, nullable=True) + discovery_year: Mapped[int | None] = mapped_column(Integer, nullable=True) + production_start_year: Mapped[int | None] = mapped_column(Integer, nullable=True) + fid_year: Mapped[int | None] = mapped_column(Integer, nullable=True) + + # Store enum-ish fields as strings (decouple DB from enum implementation details). + location_type: Mapped[str | None] = mapped_column(String, nullable=True) + production_conventionality: Mapped[str | None] = mapped_column(String, nullable=True) + primary_hydrocarbon_group: Mapped[str | None] = mapped_column(String, nullable=True) + field_status: Mapped[str | None] = mapped_column(String, nullable=True) + + # Owners/operators are structured; keep them as jsonb columns. + owners: Mapped[list[dict[str, Any]] | None] = mapped_column(JSONB, nullable=True) + operators: Mapped[list[dict[str, Any]] | None] = mapped_column(JSONB, nullable=True) + + # Raw/original input payload (includes extra fields not in OilGasFieldBase). + original_payload: Mapped[dict[str, Any]] = mapped_column( + JSONB, nullable=False, default=dict + ) + # Tell PayloadMixin what type to validate/serialize payload_type: ClassVar[type[OilGasFieldBase]] = OilGasFieldBase # Domain-agnostic: payloads don’t have to have `source` model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True) + + def set_domain(self, domain: OilGasFieldBase) -> None: + """Populate columns from the validated domain model.""" + self.name = domain.name + self.country = domain.country + self.latitude = domain.latitude + self.longitude = domain.longitude + self.last_updated = domain.last_updated + self.name_local = domain.name_local + self.state_province = domain.state_province + self.region = domain.region + self.basin = domain.basin + self.reservoir_formation = domain.reservoir_formation + self.discovery_year = domain.discovery_year + self.production_start_year = domain.production_start_year + self.fid_year = domain.fid_year + self.location_type = _enum_value(domain.location_type) + self.production_conventionality = _enum_value(domain.production_conventionality) + self.primary_hydrocarbon_group = _enum_value(domain.primary_hydrocarbon_group) + self.field_status = _enum_value(domain.field_status) + self.owners = ( + [o.model_dump(mode="json") for o in domain.owners] if domain.owners else None + ) + self.operators = ( + [o.model_dump(mode="json") for o in domain.operators] if domain.operators else None + ) diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py index 209ac42..3424308 100644 --- a/deployments/api/src/stitch/api/routers/oil_gas_fields.py +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Any from collections.abc import Sequence from fastapi import APIRouter, HTTPException @@ -20,18 +21,22 @@ @router.post("/", response_model=OGFieldView) async def create_oil_gas_field( - payload: OilGasFieldBase, + payload: dict[str, Any], uow: UnitOfWorkDep, user: CurrentUser, ): session: AsyncSession = uow.session + # Validate to the canonical package model (drops/ignores unknown fields), + # but keep raw input in original_payload for traceability. + domain = OilGasFieldBase.model_validate(payload) + # Create the generic resource first (label derived from OG name) created_res = await resource_actions.create( session=session, user=user, resource=resource_actions.CreateResource( - name=payload.name + name=domain.name ), # adjust import if CreateResource lives elsewhere in your branch ) @@ -40,12 +45,14 @@ async def create_oil_gas_field( created_by_id=user.id, last_updated_by_id=user.id, ) - og.payload = payload + og.original_payload = payload + og.payload = domain + og.set_domain(domain) session.add(og) await session.flush() # Package response type - return OGFieldView(id=og.resource_id, **payload.model_dump()) + return OGFieldView(id=og.resource_id, **domain.model_dump()) @router.get("/", response_model=Sequence[OGFieldView]) From 611352071b9d57cc6dce681b62af111a570b8a03 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 10:39:51 +0100 Subject: [PATCH 15/25] require user for fetching --- deployments/api/src/stitch/api/routers/oil_gas_fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py index 3424308..30289d1 100644 --- a/deployments/api/src/stitch/api/routers/oil_gas_fields.py +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -56,7 +56,7 @@ async def create_oil_gas_field( @router.get("/", response_model=Sequence[OGFieldView]) -async def list_oil_gas_fields(uow: UnitOfWorkDep): +async def list_oil_gas_fields(uow: UnitOfWorkDep, user: CurrentUser): session: AsyncSession = uow.session rows = (await session.execute(select(OilGasFieldModel))).scalars().all() @@ -68,7 +68,7 @@ async def list_oil_gas_fields(uow: UnitOfWorkDep): @router.get("/{id}", response_model=OGFieldView) -async def get_oil_gas_field(id: int, uow: UnitOfWorkDep): +async def get_oil_gas_field(id: int, uow: UnitOfWorkDep, user: CurrentUser): session: AsyncSession = uow.session row = await session.get(OilGasFieldModel, id) if row is None: From faaf88f19dddeb118b1dc31ee5d8c3184e065d05 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 11:03:44 +0100 Subject: [PATCH 16/25] rename source field file --- .../api/db/model/{oil_gas_field.py => oil_gas_field_source.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename deployments/api/src/stitch/api/db/model/{oil_gas_field.py => oil_gas_field_source.py} (100%) diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field.py b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py similarity index 100% rename from deployments/api/src/stitch/api/db/model/oil_gas_field.py rename to deployments/api/src/stitch/api/db/model/oil_gas_field_source.py From 2c2de18c90fbb05d95e84c124b15f54ca3d5c985 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 13:41:06 +0100 Subject: [PATCH 17/25] use OilGasFieldSourceModel --- .../api/src/stitch/api/db/model/__init__.py | 4 +- .../api/db/model/oil_gas_field_source.py | 98 +++++-------------- .../stitch/api/db/og_field_source_actions.py | 71 ++++++++++++++ .../src/stitch/api/routers/oil_gas_fields.py | 67 +++++-------- 4 files changed, 118 insertions(+), 122 deletions(-) create mode 100644 deployments/api/src/stitch/api/db/og_field_source_actions.py diff --git a/deployments/api/src/stitch/api/db/model/__init__.py b/deployments/api/src/stitch/api/db/model/__init__.py index e4bb764..d81c8af 100644 --- a/deployments/api/src/stitch/api/db/model/__init__.py +++ b/deployments/api/src/stitch/api/db/model/__init__.py @@ -1,5 +1,5 @@ from .common import Base as StitchBase -from .oil_gas_field import OilGasFieldModel +from .oil_gas_field_source import OilGasFieldSourceModel from .resource import MembershipStatus, MembershipModel, ResourceModel from .user import User as UserModel @@ -9,5 +9,5 @@ "ResourceModel", "StitchBase", "UserModel", - "OilGasFieldModel", + "OilGasFieldSourceModel", ] diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py index d1e6b06..dca3435 100644 --- a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py @@ -1,93 +1,41 @@ from __future__ import annotations -from typing import Any, ClassVar - -from pydantic import ConfigDict -from sqlalchemy import DateTime, Float, ForeignKey, Integer, String -from sqlalchemy.dialects.postgresql import JSONB +from typing import Any + +from sqlalchemy import ( + Integer, + String, + JSON, + Column, + ForeignKey, +) from sqlalchemy.orm import Mapped, mapped_column from stitch.ogsi.model.og_field import OilGasFieldBase from .common import Base -from .mixins import PayloadMixin, TimestampMixin, UserAuditMixin - +from .mixins import TimestampMixin, UserAuditMixin -def _enum_value(v: Any) -> Any: - # Works for python Enums (return .value) but leaves strings/None unchanged. - return getattr(v, "value", v) +class OilGasFieldSourceModel(TimestampMixin, UserAuditMixin, Base): + """A single OG field source record (canonicalized), feedable into a Resource.""" -class OilGasFieldModel( - TimestampMixin, UserAuditMixin, PayloadMixin[OilGasFieldBase], Base -): - """Domain wrapper for an OG field, 1:1 with a Resource.""" + __tablename__ = "oil_gas_field_source" - __tablename__ = "oil_gas_fields" + id: Mapped[int] = mapped_column(Integer, primary_key=True) - # Use resource_id as both PK and FK: keeps ids consistent across /resources and /oil_gas_fields - resource_id: Mapped[int] = mapped_column( - ForeignKey("resources.id"), primary_key=True - ) + source: Mapped[str | None] = mapped_column(String, nullable=True) - # --- Domain columns (queryable / indexed later if needed) --- + # Flat domain columns for filtering, indexing, query, etc. name: Mapped[str | None] = mapped_column(String, nullable=True) - country: Mapped[str | None] = mapped_column(String(3), nullable=True) - latitude: Mapped[float | None] = mapped_column(Float, nullable=True) - longitude: Mapped[float | None] = mapped_column(Float, nullable=True) - last_updated: Mapped[Any | None] = mapped_column(DateTime(timezone=True), nullable=True) - name_local: Mapped[str | None] = mapped_column(String, nullable=True) - state_province: Mapped[str | None] = mapped_column(String, nullable=True) - region: Mapped[str | None] = mapped_column(String, nullable=True) + country: Mapped[str | None] = mapped_column(String, nullable=True) basin: Mapped[str | None] = mapped_column(String, nullable=True) - reservoir_formation: Mapped[str | None] = mapped_column(String, nullable=True) - discovery_year: Mapped[int | None] = mapped_column(Integer, nullable=True) - production_start_year: Mapped[int | None] = mapped_column(Integer, nullable=True) - fid_year: Mapped[int | None] = mapped_column(Integer, nullable=True) - - # Store enum-ish fields as strings (decouple DB from enum implementation details). - location_type: Mapped[str | None] = mapped_column(String, nullable=True) - production_conventionality: Mapped[str | None] = mapped_column(String, nullable=True) - primary_hydrocarbon_group: Mapped[str | None] = mapped_column(String, nullable=True) - field_status: Mapped[str | None] = mapped_column(String, nullable=True) - - # Owners/operators are structured; keep them as jsonb columns. - owners: Mapped[list[dict[str, Any]] | None] = mapped_column(JSONB, nullable=True) - operators: Mapped[list[dict[str, Any]] | None] = mapped_column(JSONB, nullable=True) - - # Raw/original input payload (includes extra fields not in OilGasFieldBase). - original_payload: Mapped[dict[str, Any]] = mapped_column( - JSONB, nullable=False, default=dict - ) - # Tell PayloadMixin what type to validate/serialize - payload_type: ClassVar[type[OilGasFieldBase]] = OilGasFieldBase + # full normalized domain payload + payload: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) - # Domain-agnostic: payloads don’t have to have `source` - model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True) + # original raw payload as given by client + original_payload: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) - def set_domain(self, domain: OilGasFieldBase) -> None: - """Populate columns from the validated domain model.""" - self.name = domain.name - self.country = domain.country - self.latitude = domain.latitude - self.longitude = domain.longitude - self.last_updated = domain.last_updated - self.name_local = domain.name_local - self.state_province = domain.state_province - self.region = domain.region - self.basin = domain.basin - self.reservoir_formation = domain.reservoir_formation - self.discovery_year = domain.discovery_year - self.production_start_year = domain.production_start_year - self.fid_year = domain.fid_year - self.location_type = _enum_value(domain.location_type) - self.production_conventionality = _enum_value(domain.production_conventionality) - self.primary_hydrocarbon_group = _enum_value(domain.primary_hydrocarbon_group) - self.field_status = _enum_value(domain.field_status) - self.owners = ( - [o.model_dump(mode="json") for o in domain.owners] if domain.owners else None - ) - self.operators = ( - [o.model_dump(mode="json") for o in domain.operators] if domain.operators else None - ) + # optionally track an external ref if useful + source_ref: Mapped[str | None] = mapped_column(String, nullable=True) diff --git a/deployments/api/src/stitch/api/db/og_field_source_actions.py b/deployments/api/src/stitch/api/db/og_field_source_actions.py new file mode 100644 index 0000000..b4ff238 --- /dev/null +++ b/deployments/api/src/stitch/api/db/og_field_source_actions.py @@ -0,0 +1,71 @@ +from typing import Mapping + +from fastapi import HTTPException +from starlette.status import HTTP_404_NOT_FOUND +from sqlalchemy import select + +from stitch.ogsi.model.og_field import OilGasFieldBase + +from .model import OilGasFieldSourceModel, ResourceModel, MembershipModel + + +async def create_source( + session, + raw_payload: dict[str, object], + *, + source_system: str | None = None, + source_ref: str | None = None, +) -> OilGasFieldSourceModel: + """Validate raw JSON into domain model, persist canonical + original.""" + + # domain validation (pydantic) + domain: OilGasFieldBase = OilGasFieldBase.model_validate(raw_payload) + + model = OilGasFieldSourceModel( + source_system=source_system, + source_ref=source_ref, + name=domain.name, + country=domain.country, + basin=domain.basin, + payload=domain.model_dump(), + original_payload=raw_payload, + ) + session.add(model) + await session.flush() + return model + + +async def attach_to_resource( + session, + resource_id: int, + source_row: OilGasFieldSourceModel, + created_by, +): + """Link an OG field source to a resource via membership.""" + session.add( + MembershipModel.create( + created_by=created_by, + resource=session.get(ResourceModel, resource_id), + source="og_field", + source_pk=source_row.id, + ) + ) + await session.flush() + + +async def get_source(session, id: int) -> OilGasFieldSourceModel: + model = await session.get(OilGasFieldSourceModel, id) + if model is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, detail=f"No OG field source with id `{id}`" + ) + return model + +async def list_og_resources(session): + stmt = ( + select(ResourceModel) + .join(MembershipModel, MembershipModel.resource_id == ResourceModel.id) + .where(MembershipModel.source == "og_field") + .distinct() + ) + return (await session.scalars(stmt)).all() diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py index 30289d1..de225f8 100644 --- a/deployments/api/src/stitch/api/routers/oil_gas_fields.py +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -9,9 +9,11 @@ from starlette.status import HTTP_404_NOT_FOUND from stitch.api.auth import CurrentUser -from stitch.api.db import resource_actions +from stitch.api.db import resource_actions, og_field_source_actions from stitch.api.db.config import UnitOfWorkDep -from stitch.api.db.model import OilGasFieldModel +from stitch.api.db.model import OilGasFieldSourceModel +from stitch.api.db.og_field_source_actions import create_source, attach_to_resource +from stitch.api.resources.entities import CreateResource, Resource from stitch.ogsi.model.og_field import OilGasFieldBase # request model from stitch.ogsi.model import OGFieldView # response model @@ -19,63 +21,38 @@ router = APIRouter(prefix="/oil-gas-fields", tags=["oil_gas_fields"]) -@router.post("/", response_model=OGFieldView) +@router.post("/", response_model=Resource) async def create_oil_gas_field( - payload: dict[str, Any], + raw_body: dict[str, object], uow: UnitOfWorkDep, user: CurrentUser, ): - session: AsyncSession = uow.session + session = uow.session - # Validate to the canonical package model (drops/ignores unknown fields), - # but keep raw input in original_payload for traceability. - domain = OilGasFieldBase.model_validate(payload) - - # Create the generic resource first (label derived from OG name) - created_res = await resource_actions.create( + # 1) create a generic resource + resource = await resource_actions.create( session=session, user=user, - resource=resource_actions.CreateResource( - name=domain.name - ), # adjust import if CreateResource lives elsewhere in your branch + resource=CreateResource(name=raw_body.get("name")) ) - og = OilGasFieldModel( - resource_id=created_res.id, - created_by_id=user.id, - last_updated_by_id=user.id, + # 2) create canonical domain source + src = await create_source( + session=session, + raw_payload=raw_body, + source_system=raw_body.get("source_system"), ) - og.original_payload = payload - og.payload = domain - og.set_domain(domain) - session.add(og) - await session.flush() - # Package response type - return OGFieldView(id=og.resource_id, **domain.model_dump()) + # 3) attach it via membership + await attach_to_resource(session, resource.id, src, user) + return resource -@router.get("/", response_model=Sequence[OGFieldView]) +@router.get("/", response_model=list[Resource]) async def list_oil_gas_fields(uow: UnitOfWorkDep, user: CurrentUser): - session: AsyncSession = uow.session - rows = (await session.execute(select(OilGasFieldModel))).scalars().all() + return await og_field_source_actions.list_og_resources(session=uow.session) - out: list[OGFieldView] = [] - for row in rows: - p = row.payload - out.append(OGFieldView(id=row.resource_id, **p.model_dump())) - return out - -@router.get("/{id}", response_model=OGFieldView) +@router.get("/{id}", response_model=Resource) async def get_oil_gas_field(id: int, uow: UnitOfWorkDep, user: CurrentUser): - session: AsyncSession = uow.session - row = await session.get(OilGasFieldModel, id) - if row is None: - raise HTTPException( - status_code=HTTP_404_NOT_FOUND, - detail=f"No OilGasField with id `{id}` found.", - ) - - p = row.payload - return OGFieldView(id=row.resource_id, **p.model_dump()) + return await resource_actions.get(session=uow.session, id=id) From 87b30fecde865b8a2876920ad2b855998d1903b4 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 13:57:41 +0100 Subject: [PATCH 18/25] seed ogfield source table --- deployments/api/src/stitch/api/db/init_job.py | 56 ++++++++++++------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 28b66e9..d7d25f5 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -5,6 +5,7 @@ import time from enum import Enum from dataclasses import dataclass +from typing import Any from sqlalchemy import create_engine, inspect, text from sqlalchemy.exc import OperationalError @@ -14,7 +15,7 @@ ResourceModel, StitchBase, UserModel, - OilGasFieldModel, + OilGasFieldSourceModel, ) from stitch.api.entities import ( User as UserEntity, @@ -274,35 +275,48 @@ def create_seed_resources(user: UserEntity) -> list[ResourceModel]: return resources -def create_seed_oil_gas_fields( +def create_seed_oil_gas_source_fields( user: UserEntity, resources: list[ResourceModel], -) -> list[OilGasFieldModel]: +) -> list[OilGasFieldSourceModel]: """Create example OilGasField rows linked 1:1 with seeded resources.""" - # Construct payloads using the package model - payloads = [ - OilGasFieldBase( - name="Permian Alpha", - country="USA", - basin="Permian", - ), - OilGasFieldBase( - name="North Sea Bravo", - country="GBR", - basin="North Sea", - ), + raw_payloads: list[dict[str, Any]] = [ + # pretend this came from some upstream system (GEM/WM/etc) + { + "name": "Permian Alpha", + "country": "USA", + "basin": "Permian", + # extra keys demonstrate why we keep original_payload + "upstream_id": "seed-gem-0001", + "notes": "seed example", + }, + { + "name": "North Sea Bravo", + "country": "GBR", + "basin": "North Sea", + "upstream_id": "seed-wm-0002", + "notes": "seed example", + }, ] - og_models: list[OilGasFieldModel] = [] + og_models: list[OilGasFieldSourceModel] = [] - for resource, payload in zip(resources, payloads): - model = OilGasFieldModel( - resource_id=resource.id, + for resource, raw in zip(resources, raw_payloads): + domain = OilGasFieldBase.model_validate(raw) + model = OilGasFieldSourceModel( created_by_id=user.id, last_updated_by_id=user.id, ) - model.payload = payload + # Raw input (includes extra fields not in OilGasFieldBase) + model.original_payload = raw + # Canonical validated payload + model.payload = raw + model.name = domain.name + model.country = domain.country + model.basin = domain.basin + model.source = "dev-seed" + # Populate domain columns for queryability og_models.append(model) return og_models @@ -330,7 +344,7 @@ def seed_dev(engine) -> None: session.flush() # # Add sample OilGasField rows for the first two resources only - og_fields = create_seed_oil_gas_fields(user_entity, resources) + og_fields = create_seed_oil_gas_source_fields(user_entity, resources) session.add_all(og_fields) session.commit() From cbe3f12d3beebb722343ddb01841f531b096dd0c Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 14:16:00 +0100 Subject: [PATCH 19/25] wip: fetching resources, no constituents --- deployments/api/src/stitch/api/db/init_job.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index d7d25f5..d967fd3 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -13,6 +13,7 @@ from stitch.api.db.model import ( ResourceModel, + MembershipModel, StitchBase, UserModel, OilGasFieldSourceModel, @@ -275,6 +276,20 @@ def create_seed_resources(user: UserEntity) -> list[ResourceModel]: return resources +def create_seed_memberships( + user: UserEntity, + resources: list[ResourceModel], + sources: list[OilGasFieldSourceModel] + ) -> list[MembershipModel]: + memberships = [ + MembershipModel.create(user, resources[0], "og_field", 1), + MembershipModel.create(user, resources[1], "og_field", 2), + ] + for i, mem in enumerate(memberships, start=1): + mem.id = i + return memberships + + def create_seed_oil_gas_source_fields( user: UserEntity, resources: list[ResourceModel], @@ -347,6 +362,11 @@ def seed_dev(engine) -> None: og_fields = create_seed_oil_gas_source_fields(user_entity, resources) session.add_all(og_fields) + memberships = create_seed_memberships( + user_entity, resources, og_fields + ) + session.add_all(memberships) + session.commit() From ac8781bf74042a44bf08adf1da50d2ec0db0bf4a Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 14:17:24 +0100 Subject: [PATCH 20/25] style: format --- deployments/api/src/stitch/api/db/init_job.py | 8 +++----- .../api/src/stitch/api/db/og_field_source_actions.py | 1 + deployments/api/src/stitch/api/routers/oil_gas_fields.py | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index d967fd3..be72a6d 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -279,8 +279,8 @@ def create_seed_resources(user: UserEntity) -> list[ResourceModel]: def create_seed_memberships( user: UserEntity, resources: list[ResourceModel], - sources: list[OilGasFieldSourceModel] - ) -> list[MembershipModel]: + sources: list[OilGasFieldSourceModel], +) -> list[MembershipModel]: memberships = [ MembershipModel.create(user, resources[0], "og_field", 1), MembershipModel.create(user, resources[1], "og_field", 2), @@ -362,9 +362,7 @@ def seed_dev(engine) -> None: og_fields = create_seed_oil_gas_source_fields(user_entity, resources) session.add_all(og_fields) - memberships = create_seed_memberships( - user_entity, resources, og_fields - ) + memberships = create_seed_memberships(user_entity, resources, og_fields) session.add_all(memberships) session.commit() diff --git a/deployments/api/src/stitch/api/db/og_field_source_actions.py b/deployments/api/src/stitch/api/db/og_field_source_actions.py index b4ff238..f1d6614 100644 --- a/deployments/api/src/stitch/api/db/og_field_source_actions.py +++ b/deployments/api/src/stitch/api/db/og_field_source_actions.py @@ -61,6 +61,7 @@ async def get_source(session, id: int) -> OilGasFieldSourceModel: ) return model + async def list_og_resources(session): stmt = ( select(ResourceModel) diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py index de225f8..b060255 100644 --- a/deployments/api/src/stitch/api/routers/oil_gas_fields.py +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -31,9 +31,7 @@ async def create_oil_gas_field( # 1) create a generic resource resource = await resource_actions.create( - session=session, - user=user, - resource=CreateResource(name=raw_body.get("name")) + session=session, user=user, resource=CreateResource(name=raw_body.get("name")) ) # 2) create canonical domain source @@ -48,6 +46,7 @@ async def create_oil_gas_field( return resource + @router.get("/", response_model=list[Resource]) async def list_oil_gas_fields(uow: UnitOfWorkDep, user: CurrentUser): return await og_field_source_actions.list_og_resources(session=uow.session) From da6b8a047c263769feeb9cef95617330c4d36ab9 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 16:17:57 +0100 Subject: [PATCH 21/25] wip; adding in og_field memberships --- .../api/src/stitch/api/db/model/resource.py | 18 ++- .../api/src/stitch/api/db/model/sources.py | 141 ++++++++++++++++++ .../stitch/api/db/og_field_source_actions.py | 13 +- .../api/src/stitch/api/db/resource_actions.py | 112 ++++++++++++-- deployments/api/src/stitch/api/entities.py | 90 +++++++++++ .../api/src/stitch/api/resources/entities.py | 19 --- .../src/stitch/api/routers/oil_gas_fields.py | 13 +- 7 files changed, 359 insertions(+), 47 deletions(-) create mode 100644 deployments/api/src/stitch/api/db/model/sources.py delete mode 100644 deployments/api/src/stitch/api/resources/entities.py diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index c98fd49..4a04f01 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -17,7 +17,6 @@ from .mixins import TimestampMixin, UserAuditMixin from .types import PORTABLE_BIGINT - class MembershipStatus(StrEnum): ACTIVE = "ACTIVE" INACTIVE = "INACTIVE" @@ -89,6 +88,23 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base): # and configure the appropriate SQL statement to load the membership objects memberships: Mapped[list[MembershipModel]] = relationship() + async def get_source_data(self, session: AsyncSession): + pks_by_src: dict[SourceKey, set[int]] = defaultdict(set) + for mem in self.memberships: + if mem.status == MembershipStatus.ACTIVE: + pks_by_src[mem.source].add(mem.source_pk) + + results: dict[SourceKey, dict[IdType, SourceModel]] = defaultdict(dict) + for src, pks in pks_by_src.items(): + model_cls = SOURCE_TABLES.get(src) + if model_cls is None: + continue + stmt = select(model_cls).where(model_cls.id.in_(pks)) + for src_model in await session.scalars(stmt): + results[src][src_model.id] = src_model + + return SourceModelData(**results) + async def get_root(self, session: AsyncSession): root = await session.scalar(self.__class__._root_select(self.id)) if root is None: diff --git a/deployments/api/src/stitch/api/db/model/sources.py b/deployments/api/src/stitch/api/db/model/sources.py new file mode 100644 index 0000000..6b8e5af --- /dev/null +++ b/deployments/api/src/stitch/api/db/model/sources.py @@ -0,0 +1,141 @@ +from typing_extensions import Self + +from collections.abc import Mapping, MutableMapping +from typing import Final, Generic, TypeVar, TypedDict, get_args, get_origin +from pydantic import BaseModel +from sqlalchemy import CheckConstraint, inspect +from sqlalchemy.orm import Mapped, mapped_column +from .common import Base +from .types import PORTABLE_BIGINT, StitchJson +from stitch.api.entities import ( + CCReservoirsSource, + GemSource, + IdType, + RMIManualSource, + SourceKey, + WMData, + GemData, + RMIManualData, + CCReservoirsData, + WMSource, +) + + +def float_constraint( + colname: str, min_: float | None = None, max_: float | None = None +) -> CheckConstraint: + min_str = f"{colname} >= {min_}" if min_ is not None else None + max_str = f"{colname} <= {max_}" if max_ is not None else None + expr = " AND ".join(filter(None, (min_str, max_str))) + return CheckConstraint(expr) + + +def lat_constraints(colname: str): + return float_constraint(colname, -90, 90) + + +def lon_constraints(colname: str): + return float_constraint(colname, -180, 180) + + +TModelIn = TypeVar("TModelIn", bound=BaseModel) +TModelOut = TypeVar("TModelOut", bound=BaseModel) + + +class SourceBase(Base, Generic[TModelIn, TModelOut]): + __abstract__ = True + __entity_class_in__: type[TModelIn] + __entity_class_out__: type[TModelOut] + + id: Mapped[int] = mapped_column( + PORTABLE_BIGINT, primary_key=True, autoincrement=True + ) + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + for base in getattr(cls, "__orig_bases__", ()): + if get_origin(base) is SourceBase: + args = get_args(base) + if len(args) >= 2: + if isinstance(args[0], type): + cls.__entity_class_in__ = args[0] + if isinstance(args[1], type): + cls.__entity_class_out__ = args[1] + break + + def as_entity(self): + return self.__entity_class_out__.model_validate(self) + + @classmethod + def from_entity(cls, entity: TModelIn) -> Self: + mapper = inspect(cls) + column_keys = {col.key for col in mapper.columns} + filtered = {k: v for k, v in entity.model_dump().items() if k in column_keys} + return cls(**filtered) + + +class GemSourceModel(SourceBase[GemData, GemSource]): + __tablename__ = "gem_sources" + + name: Mapped[str] + country: Mapped[str] + lat: Mapped[float] = mapped_column(lat_constraints("lat")) + lon: Mapped[float] = mapped_column(lon_constraints("lon")) + + +class WMSourceModel(SourceBase[WMData, WMSource]): + __tablename__ = "wm_sources" + + field_name: Mapped[str] + field_country: Mapped[str] + production: Mapped[float] + + +class RMIManualSourceModel(SourceBase[RMIManualData, RMIManualSource]): + __tablename__ = "rmi_manual_sources" + + name_override: Mapped[str | None] + gwp: Mapped[float | None] + gor: Mapped[float | None | None] = mapped_column( + float_constraint("gor", 0, 1), nullable=True + ) + country: Mapped[str | None] + latitude: Mapped[float | None] = mapped_column( + lat_constraints("latitude"), nullable=True + ) + longitude: Mapped[float | None] = mapped_column( + lon_constraints("longitude"), nullable=True + ) + + +class CCReservoirsSourceModel(SourceBase[CCReservoirsData, CCReservoirsSource]): + __tablename__ = "cc_reservoirs_sources" + + name: Mapped[str] + basin: Mapped[str] + depth: Mapped[float] + geofence: Mapped[list[tuple[float, float]]] = mapped_column(StitchJson()) + + +SourceModel = ( + GemSourceModel | WMSourceModel | RMIManualSourceModel | CCReservoirsSourceModel +) +SourceModelCls = type[SourceModel] + +SOURCE_TABLES: Final[Mapping[SourceKey, SourceModelCls]] = { + "gem": GemSourceModel, + "wm": WMSourceModel, + "rmi": RMIManualSourceModel, + "cc": CCReservoirsSourceModel, +} + + +class SourceModelData(TypedDict, total=False): + gem: MutableMapping[IdType, GemSourceModel] + wm: MutableMapping[IdType, WMSourceModel] + cc: MutableMapping[IdType, CCReservoirsSourceModel] + rmi: MutableMapping[IdType, RMIManualSourceModel] + + +def empty_source_model_data(): + return SourceModelData(gem={}, wm={}, cc={}, rmi={}) diff --git a/deployments/api/src/stitch/api/db/og_field_source_actions.py b/deployments/api/src/stitch/api/db/og_field_source_actions.py index f1d6614..81335f1 100644 --- a/deployments/api/src/stitch/api/db/og_field_source_actions.py +++ b/deployments/api/src/stitch/api/db/og_field_source_actions.py @@ -1,12 +1,16 @@ -from typing import Mapping +import asyncio + +from functools import partial from fastapi import HTTPException from starlette.status import HTTP_404_NOT_FOUND from sqlalchemy import select +from sqlalchemy.orm import selectinload from stitch.ogsi.model.og_field import OilGasFieldBase from .model import OilGasFieldSourceModel, ResourceModel, MembershipModel +from .resource_actions import resource_model_to_entity async def create_source( @@ -61,12 +65,15 @@ async def get_source(session, id: int) -> OilGasFieldSourceModel: ) return model - async def list_og_resources(session): stmt = ( select(ResourceModel) + .where(ResourceModel.repointed_id.is_(None)) .join(MembershipModel, MembershipModel.resource_id == ResourceModel.id) .where(MembershipModel.source == "og_field") + .options(selectinload(ResourceModel.memberships)) .distinct() ) - return (await session.scalars(stmt)).all() + models = (await session.scalars(stmt)).all() + fn = partial(resource_model_to_entity, session) + return await asyncio.gather(*[fn(m) for m in models]) diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index 1ad9358..8a1a34d 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py +++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -1,38 +1,90 @@ import asyncio -from collections.abc import Sequence +from collections import defaultdict +from collections.abc import Mapping, Sequence from functools import partial from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from starlette.status import HTTP_404_NOT_FOUND -from stitch.api.resources.entities import CreateResource, Resource + +from stitch.api.db.model.sources import SOURCE_TABLES, SourceModel from stitch.api.auth import CurrentUser +from stitch.api.entities import ( + CreateResource, + CreateResourceSourceData, + CreateSourceData, + Resource, + SourceData, + SourceKey, +) from .model import ( + MembershipModel, ResourceModel, ) +from stitch.ogsi.model.og_field import OilGasFieldBase + + +async def get_or_create_source_models( + session: AsyncSession, + data: CreateResourceSourceData, +) -> Mapping[SourceKey, Sequence[SourceModel]]: + result: dict[SourceKey, list[SourceModel]] = defaultdict(list) + for key, model_cls in SOURCE_TABLES.items(): + for item in data.get(key): + if isinstance(item, int): + src_model = await session.get(model_cls, item) + if src_model is None: + continue + result[key].append(src_model) + else: + result[key].append(model_cls.from_entity(item)) + return result + + +def resource_model_to_empty_entity(model: ResourceModel): + return Resource( + id=model.id, + name=model.name, + country=model.country, + source_data=SourceData(), + constituents=[], + created=model.created, + updated=model.updated, + ) async def resource_model_to_entity( session: AsyncSession, model: ResourceModel ) -> Resource: - # Domain-agnostic: constituents are just ids, and there is no source_data. + source_model_data = await model.get_source_data(session) + source_data = SourceData.model_validate(source_model_data) constituent_models = await ResourceModel.get_constituents_by_root_id( session, model.id ) - constituent_ids = [m.id for m in constituent_models if m.id != model.id] + constituents = [ + resource_model_to_empty_entity(cm) + for cm in constituent_models + if cm.id != model.id + ] return Resource( id=model.id, name=model.name, - repointed_to=model.repointed_id, - constituents=constituent_ids, - created=str(model.created) if getattr(model, "created", None) else None, - updated=str(model.updated) if getattr(model, "updated", None) else None, + country=model.country, + source_data=source_data, + constituents=constituents, + created=model.created, + updated=model.updated, ) async def get_all(session: AsyncSession) -> Sequence[Resource]: - stmt = select(ResourceModel).where(ResourceModel.repointed_id.is_(None)) + stmt = ( + select(ResourceModel) + .where(ResourceModel.repointed_id.is_(None)) + .options(selectinload(ResourceModel.memberships)) + ) models = (await session.scalars(stmt)).all() fn = partial(resource_model_to_entity, session) return await asyncio.gather(*[fn(m) for m in models]) @@ -44,18 +96,52 @@ async def get(session: AsyncSession, id: int): raise HTTPException( status_code=HTTP_404_NOT_FOUND, detail=f"No Resource with id `{id}` found." ) + await session.refresh(model, ["memberships"]) return await resource_model_to_entity(session, model) async def create(session: AsyncSession, user: CurrentUser, resource: CreateResource): """ - Domain-agnostic create: - - create the resource row only + Here we create a resource either from new source data or existing source data. It's also possible + to create an empty resource with no reference to source data. + + - create the resource + - create the sources + - create membership """ model = ResourceModel.create( - created_by=user, - name=resource.name, + created_by=user, name=resource.name, country=resource.country ) session.add(model) + if resource.source_data: + src_model_groups = await get_or_create_source_models( + session, resource.source_data + ) + for src_key, src_models in src_model_groups.items(): + session.add_all(src_models) + await session.flush() + for src_model in src_models: + session.add( + MembershipModel.create( + created_by=user, + resource=model, + source=src_key, + source_pk=src_model.id, + ) + ) await session.flush() + await session.refresh(model, ["memberships"]) return await resource_model_to_entity(session, model) + + +async def create_source_data(session: AsyncSession, data: CreateSourceData): + """ + For bulk inserting data into source tables. + """ + og_fields = tuple(OilGasFieldBase.from_entity(og_field) for og_field in data.og_field) + + session.add_all(og_fields) + await session.flush() + return SourceData( + og_field={g.id: g.as_entity() for g in og_fields}, + ) diff --git a/deployments/api/src/stitch/api/entities.py b/deployments/api/src/stitch/api/entities.py index fecbfab..f8451ce 100644 --- a/deployments/api/src/stitch/api/entities.py +++ b/deployments/api/src/stitch/api/entities.py @@ -1,5 +1,92 @@ +from collections.abc import Sequence +from datetime import datetime +from typing import ( + Generic, + Mapping, + Protocol, + TypeVar, + runtime_checkable, +) +from uuid import UUID from pydantic import BaseModel, EmailStr, Field +from stitch.ogsi.model.types import OGSISrcKey +from stitch.ogsi.model.og_field import OilGasFieldBase + +IdType = int | str | UUID + + +@runtime_checkable +class HasId(Protocol): + @property + def id(self) -> IdType: ... + + +TSourceKey = TypeVar("TSourceKey", bound=OGSISrcKey) + + +class Timestamped(BaseModel): + created: datetime = Field(default_factory=datetime.now) + updated: datetime = Field(default_factory=datetime.now) + + +class Identified(BaseModel): + id: IdType + + +class SourceBase(BaseModel, Generic[TSourceKey]): + source: TSourceKey + id: IdType + + +class SourceRef(BaseModel): + source: OGSISrcKey + id: int + + +# The sources will come in and be initially stored in a raw table. +# That raw table will be an append-only table. +# We'll translate that data into one of the below structures, so each source will have a `UUID` or similar that +# references their id in the "raw" table. +# When pulling into the internal "sources" table, each will get a new unique id which is what the memberships will reference + + +class SourceData(BaseModel): + og_field: Mapping[IdType, OilGasFieldBase] = Field(default_factory=dict) + + +class CreateSourceData(BaseModel): + og_field: Sequence[OilGasFieldBase] = Field(default_factory=list) + + +class CreateResourceSourceData(BaseModel): + """Allows for creating source data or referencing existing sources by ID. + + It can be used in isolation to insert source data or used with a new/existing resource to automatically add + memberships to the resource. + """ + + og_field: Sequence[OilGasFieldBase | int] = Field(default_factory=list) + + def get(self, key: OGSISrcKey): + raise ValueError(f"Unknown source key: {key}") + + +class ResourceBase(BaseModel): + name: str | None = Field(default=None) + country: str | None = Field(default=None) + repointed_to: "Resource | None" = Field(default=None) + + +class Resource(ResourceBase, Timestamped): + id: int + source_data: SourceData + constituents: Sequence["Resource"] + + +class CreateResource(ResourceBase): + source_data: CreateResourceSourceData | None + class User(BaseModel): id: int = Field(...) @@ -7,3 +94,6 @@ class User(BaseModel): role: str | None = None email: EmailStr name: str + + +class SourceSelectionLogic(BaseModel): ... diff --git a/deployments/api/src/stitch/api/resources/entities.py b/deployments/api/src/stitch/api/resources/entities.py deleted file mode 100644 index e1f1471..0000000 --- a/deployments/api/src/stitch/api/resources/entities.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from pydantic import BaseModel, ConfigDict, Field - - -class CreateResource(BaseModel): - """Domain-agnostic create model.""" - - model_config = ConfigDict(extra="forbid") - name: str | None = None - - -class Resource(BaseModel): - """Domain-agnostic read model.""" - - id: int - name: str | None = None - repointed_to: int | None = None - constituents: frozenset[int] = Field(default_factory=frozenset) diff --git a/deployments/api/src/stitch/api/routers/oil_gas_fields.py b/deployments/api/src/stitch/api/routers/oil_gas_fields.py index b060255..bb8496a 100644 --- a/deployments/api/src/stitch/api/routers/oil_gas_fields.py +++ b/deployments/api/src/stitch/api/routers/oil_gas_fields.py @@ -1,22 +1,13 @@ from __future__ import annotations -from typing import Any -from collections.abc import Sequence -from fastapi import APIRouter, HTTPException -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from starlette.status import HTTP_404_NOT_FOUND +from fastapi import APIRouter from stitch.api.auth import CurrentUser from stitch.api.db import resource_actions, og_field_source_actions from stitch.api.db.config import UnitOfWorkDep -from stitch.api.db.model import OilGasFieldSourceModel from stitch.api.db.og_field_source_actions import create_source, attach_to_resource -from stitch.api.resources.entities import CreateResource, Resource - -from stitch.ogsi.model.og_field import OilGasFieldBase # request model -from stitch.ogsi.model import OGFieldView # response model +from stitch.api.entities import CreateResource, Resource router = APIRouter(prefix="/oil-gas-fields", tags=["oil_gas_fields"]) From 2179929325320cca0681adebf150dc0b671a7e12 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 18:22:42 +0100 Subject: [PATCH 22/25] wip: restore api-level Source info --- .../api/db/model/oil_gas_field_source.py | 4 -- .../api/src/stitch/api/db/model/resource.py | 10 ++- .../api/src/stitch/api/db/model/sources.py | 72 +++---------------- deployments/api/src/stitch/api/entities.py | 7 +- .../api/src/stitch/api/ogsi/entities.py | 18 ----- 5 files changed, 24 insertions(+), 87 deletions(-) delete mode 100644 deployments/api/src/stitch/api/ogsi/entities.py diff --git a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py index dca3435..b856e7c 100644 --- a/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py +++ b/deployments/api/src/stitch/api/db/model/oil_gas_field_source.py @@ -6,13 +6,9 @@ Integer, String, JSON, - Column, - ForeignKey, ) from sqlalchemy.orm import Mapped, mapped_column -from stitch.ogsi.model.og_field import OilGasFieldBase - from .common import Base from .mixins import TimestampMixin, UserAuditMixin diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index 4a04f01..bbdf443 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -1,3 +1,4 @@ +from collections import defaultdict from enum import StrEnum from sqlalchemy import ( ForeignKey, @@ -11,6 +12,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship +from .sources import ( + SOURCE_TABLES, + SourceKey, + SourceModel, + SourceModelData, +) + from stitch.models.types import IdType from stitch.api.entities import User as UserEntity from .common import Base @@ -52,7 +60,7 @@ def create( model = cls( resource_id=resource.id, source=source, - source_pk=int(source_pk), + source_pk=str(source_pk), status=status, created_by_id=created_by.id, last_updated_by_id=created_by.id, diff --git a/deployments/api/src/stitch/api/db/model/sources.py b/deployments/api/src/stitch/api/db/model/sources.py index 6b8e5af..f8dbb5e 100644 --- a/deployments/api/src/stitch/api/db/model/sources.py +++ b/deployments/api/src/stitch/api/db/model/sources.py @@ -6,20 +6,13 @@ from sqlalchemy import CheckConstraint, inspect from sqlalchemy.orm import Mapped, mapped_column from .common import Base -from .types import PORTABLE_BIGINT, StitchJson +from .types import PORTABLE_BIGINT from stitch.api.entities import ( - CCReservoirsSource, - GemSource, IdType, - RMIManualSource, SourceKey, - WMData, - GemData, - RMIManualData, - CCReservoirsData, - WMSource, ) +from .oil_gas_field_source import OilGasFieldSourceModel def float_constraint( colname: str, min_: float | None = None, max_: float | None = None @@ -74,67 +67,24 @@ def from_entity(cls, entity: TModelIn) -> Self: return cls(**filtered) -class GemSourceModel(SourceBase[GemData, GemSource]): - __tablename__ = "gem_sources" - - name: Mapped[str] - country: Mapped[str] - lat: Mapped[float] = mapped_column(lat_constraints("lat")) - lon: Mapped[float] = mapped_column(lon_constraints("lon")) - - -class WMSourceModel(SourceBase[WMData, WMSource]): - __tablename__ = "wm_sources" - - field_name: Mapped[str] - field_country: Mapped[str] - production: Mapped[float] - - -class RMIManualSourceModel(SourceBase[RMIManualData, RMIManualSource]): - __tablename__ = "rmi_manual_sources" - - name_override: Mapped[str | None] - gwp: Mapped[float | None] - gor: Mapped[float | None | None] = mapped_column( - float_constraint("gor", 0, 1), nullable=True - ) - country: Mapped[str | None] - latitude: Mapped[float | None] = mapped_column( - lat_constraints("latitude"), nullable=True - ) - longitude: Mapped[float | None] = mapped_column( - lon_constraints("longitude"), nullable=True - ) - - -class CCReservoirsSourceModel(SourceBase[CCReservoirsData, CCReservoirsSource]): - __tablename__ = "cc_reservoirs_sources" - - name: Mapped[str] - basin: Mapped[str] - depth: Mapped[float] - geofence: Mapped[list[tuple[float, float]]] = mapped_column(StitchJson()) - - SourceModel = ( - GemSourceModel | WMSourceModel | RMIManualSourceModel | CCReservoirsSourceModel + OilGasFieldSourceModel ) SourceModelCls = type[SourceModel] SOURCE_TABLES: Final[Mapping[SourceKey, SourceModelCls]] = { - "gem": GemSourceModel, - "wm": WMSourceModel, - "rmi": RMIManualSourceModel, - "cc": CCReservoirsSourceModel, + "gem": OilGasFieldSourceModel, + "wm": OilGasFieldSourceModel, + "rmi": OilGasFieldSourceModel, + "llm": OilGasFieldSourceModel, } class SourceModelData(TypedDict, total=False): - gem: MutableMapping[IdType, GemSourceModel] - wm: MutableMapping[IdType, WMSourceModel] - cc: MutableMapping[IdType, CCReservoirsSourceModel] - rmi: MutableMapping[IdType, RMIManualSourceModel] + gem: MutableMapping[IdType, OilGasFieldSourceModel] + wm: MutableMapping[IdType, OilGasFieldSourceModel] + cc: MutableMapping[IdType, OilGasFieldSourceModel] + rmi: MutableMapping[IdType, OilGasFieldSourceModel] def empty_source_model_data(): diff --git a/deployments/api/src/stitch/api/entities.py b/deployments/api/src/stitch/api/entities.py index f8451ce..f15b418 100644 --- a/deployments/api/src/stitch/api/entities.py +++ b/deployments/api/src/stitch/api/entities.py @@ -22,7 +22,8 @@ class HasId(Protocol): def id(self) -> IdType: ... -TSourceKey = TypeVar("TSourceKey", bound=OGSISrcKey) +SourceKey = OGSISrcKey +TSourceKey = TypeVar("TSourceKey", bound=SourceKey) class Timestamped(BaseModel): @@ -40,7 +41,7 @@ class SourceBase(BaseModel, Generic[TSourceKey]): class SourceRef(BaseModel): - source: OGSISrcKey + source: SourceKey id: int @@ -68,7 +69,7 @@ class CreateResourceSourceData(BaseModel): og_field: Sequence[OilGasFieldBase | int] = Field(default_factory=list) - def get(self, key: OGSISrcKey): + def get(self, key: SourceKey): raise ValueError(f"Unknown source key: {key}") diff --git a/deployments/api/src/stitch/api/ogsi/entities.py b/deployments/api/src/stitch/api/ogsi/entities.py deleted file mode 100644 index 487336c..0000000 --- a/deployments/api/src/stitch/api/ogsi/entities.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from pydantic import BaseModel, ConfigDict, Field -from stitch.api.resources.entities import CreateResource, Resource as ResourceView - - -class CreateOilGasField(BaseModel): - model_config = ConfigDict(extra="forbid") - resource: CreateResource - owner: str | None = Field(default=None) - operator: str | None = Field(default=None) - - -class OilGasField(BaseModel): - id: int - resource: ResourceView - owner: str | None = None - operator: str | None = None From 7b6a2991a864496c1eac2fd1f3be982d999efdf4 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 19:12:46 +0100 Subject: [PATCH 23/25] wip: more source memberships --- deployments/api/src/stitch/api/db/init_job.py | 4 ++-- deployments/api/src/stitch/api/db/model/mixins.py | 3 ++- deployments/api/src/stitch/api/db/og_field_source_actions.py | 3 +-- deployments/api/src/stitch/api/db/resource_actions.py | 4 +--- deployments/api/src/stitch/api/routers/resources.py | 2 +- packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py | 5 +---- 6 files changed, 8 insertions(+), 13 deletions(-) diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index be72a6d..45ffdf3 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -282,8 +282,8 @@ def create_seed_memberships( sources: list[OilGasFieldSourceModel], ) -> list[MembershipModel]: memberships = [ - MembershipModel.create(user, resources[0], "og_field", 1), - MembershipModel.create(user, resources[1], "og_field", 2), + MembershipModel.create(user, resources[0], "gem", 1), + MembershipModel.create(user, resources[1], "wm", 2), ] for i, mem in enumerate(memberships, start=1): mem.id = i diff --git a/deployments/api/src/stitch/api/db/model/mixins.py b/deployments/api/src/stitch/api/db/model/mixins.py index e689933..f26cf41 100644 --- a/deployments/api/src/stitch/api/db/model/mixins.py +++ b/deployments/api/src/stitch/api/db/model/mixins.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column +from stitch.api.entities import SourceBase from .types import StitchJson @@ -29,7 +30,7 @@ class UserAuditMixin: last_updated_by_id: Mapped[int] = mapped_column(ForeignKey("users.id")) -TPayload = TypeVar("TPayload", bound=BaseModel) +TPayload = TypeVar("TPayload", bound=SourceBase) def _extract_payload_type(cls: type) -> type | None: diff --git a/deployments/api/src/stitch/api/db/og_field_source_actions.py b/deployments/api/src/stitch/api/db/og_field_source_actions.py index 81335f1..acd19f5 100644 --- a/deployments/api/src/stitch/api/db/og_field_source_actions.py +++ b/deployments/api/src/stitch/api/db/og_field_source_actions.py @@ -50,7 +50,7 @@ async def attach_to_resource( MembershipModel.create( created_by=created_by, resource=session.get(ResourceModel, resource_id), - source="og_field", + source="rmi", source_pk=source_row.id, ) ) @@ -70,7 +70,6 @@ async def list_og_resources(session): select(ResourceModel) .where(ResourceModel.repointed_id.is_(None)) .join(MembershipModel, MembershipModel.resource_id == ResourceModel.id) - .where(MembershipModel.source == "og_field") .options(selectinload(ResourceModel.memberships)) .distinct() ) diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index 8a1a34d..1fdabf0 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py +++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -47,7 +47,6 @@ def resource_model_to_empty_entity(model: ResourceModel): return Resource( id=model.id, name=model.name, - country=model.country, source_data=SourceData(), constituents=[], created=model.created, @@ -71,7 +70,6 @@ async def resource_model_to_entity( return Resource( id=model.id, name=model.name, - country=model.country, source_data=source_data, constituents=constituents, created=model.created, @@ -110,7 +108,7 @@ async def create(session: AsyncSession, user: CurrentUser, resource: CreateResou - create membership """ model = ResourceModel.create( - created_by=user, name=resource.name, country=resource.country + created_by=user, name=resource.name ) session.add(model) if resource.source_data: diff --git a/deployments/api/src/stitch/api/routers/resources.py b/deployments/api/src/stitch/api/routers/resources.py index a2e7d94..8d63c15 100644 --- a/deployments/api/src/stitch/api/routers/resources.py +++ b/deployments/api/src/stitch/api/routers/resources.py @@ -5,7 +5,7 @@ from stitch.api.db import resource_actions from stitch.api.db.config import UnitOfWorkDep from stitch.api.auth import CurrentUser -from stitch.api.resources.entities import CreateResource, Resource +from stitch.api.entities import CreateResource, Resource router = APIRouter( diff --git a/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py b/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py index d991d75..88ae876 100644 --- a/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py +++ b/packages/stitch-ogsi/src/stitch/ogsi/model/__init__.py @@ -61,10 +61,7 @@ class LLMSource(Source[int, LLMSrcKey], OilGasFieldBase): class OGSourcePayload(SourcePayload): - gem: Sequence[GemSource] = Field(default_factory=lambda: []) - wm: Sequence[WoodMacSource] = Field(default_factory=lambda: []) - rmi: Sequence[RMISource] = Field(default_factory=lambda: []) - llm: Sequence[LLMSource] = Field(default_factory=lambda: []) + og_field: Sequence[OGFieldSource] = Field(default_factory=lambda: []) class OGFieldResource(OilGasFieldBase, Resource[int, OGSourcePayload]): ... From 50423b7c7f1c636e91eeb99df812fc9598df0187 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 19:14:14 +0100 Subject: [PATCH 24/25] restore mixins --- deployments/api/src/stitch/api/db/model/mixins.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/deployments/api/src/stitch/api/db/model/mixins.py b/deployments/api/src/stitch/api/db/model/mixins.py index f26cf41..b579582 100644 --- a/deployments/api/src/stitch/api/db/model/mixins.py +++ b/deployments/api/src/stitch/api/db/model/mixins.py @@ -1,6 +1,6 @@ from datetime import datetime from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin -from pydantic import BaseModel, TypeAdapter +from pydantic import TypeAdapter from sqlalchemy import DateTime, ForeignKey, String, func from sqlalchemy.ext.hybrid import hybrid_property @@ -63,9 +63,7 @@ def payload(self) -> TPayload: @payload.inplace.setter def _payload_setter(self, value: TPayload): - # Domain-agnostic: if the payload has a `source` attribute, keep it; - # otherwise set a neutral default. - self.source = getattr(value, "source", "unknown") + self.source = value.source self._payload_data = value.model_dump(mode="json") @payload.inplace.expression From 4a710453fc086ed405b56f0b467740942f23e5c4 Mon Sep 17 00:00:00 2001 From: Alex Axthelm Date: Thu, 5 Mar 2026 19:32:33 +0100 Subject: [PATCH 25/25] wip: source model --- deployments/api/src/stitch/api/db/model/resource.py | 7 +++++-- deployments/api/src/stitch/api/db/model/sources.py | 7 ++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/deployments/api/src/stitch/api/db/model/resource.py b/deployments/api/src/stitch/api/db/model/resource.py index bbdf443..a366fdd 100644 --- a/deployments/api/src/stitch/api/db/model/resource.py +++ b/deployments/api/src/stitch/api/db/model/resource.py @@ -25,6 +25,7 @@ from .mixins import TimestampMixin, UserAuditMixin from .types import PORTABLE_BIGINT + class MembershipStatus(StrEnum): ACTIVE = "ACTIVE" INACTIVE = "INACTIVE" @@ -44,7 +45,9 @@ class MembershipModel(TimestampMixin, UserAuditMixin, Base): ) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) resource_id: Mapped[int] = mapped_column(ForeignKey("resources.id"), nullable=False) - source: Mapped[str] = mapped_column(String(10), nullable=False) # "gem" | "wm" + source: Mapped[SourceKey] = mapped_column( + String(10), nullable=False + ) # "gem" | "wm" source_pk: Mapped[int] = mapped_column(PORTABLE_BIGINT, nullable=False) status: Mapped[MembershipStatus] @@ -53,7 +56,7 @@ def create( cls, created_by: UserEntity, resource: "ResourceModel", - source: str, + source: SourceKey, source_pk: IdType, status: MembershipStatus = MembershipStatus.ACTIVE, ): diff --git a/deployments/api/src/stitch/api/db/model/sources.py b/deployments/api/src/stitch/api/db/model/sources.py index f8dbb5e..da1dede 100644 --- a/deployments/api/src/stitch/api/db/model/sources.py +++ b/deployments/api/src/stitch/api/db/model/sources.py @@ -81,11 +81,8 @@ def from_entity(cls, entity: TModelIn) -> Self: class SourceModelData(TypedDict, total=False): - gem: MutableMapping[IdType, OilGasFieldSourceModel] - wm: MutableMapping[IdType, OilGasFieldSourceModel] - cc: MutableMapping[IdType, OilGasFieldSourceModel] - rmi: MutableMapping[IdType, OilGasFieldSourceModel] + og_field: MutableMapping[IdType, OilGasFieldSourceModel] def empty_source_model_data(): - return SourceModelData(gem={}, wm={}, cc={}, rmi={}) + return SourceModelData(og_field={})