Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deployments/api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"pydantic-settings>=2.12.0",
"sqlalchemy>=2.0.44",
"stitch-auth",
"stitch-models",
]

[project.scripts]
Expand Down Expand Up @@ -41,3 +42,4 @@ addopts = ["-v", "--strict-markers", "--tb=short"]

[tool.uv.sources]
stitch-auth = { workspace = true }
stitch-models = { workspace = true }
96 changes: 2 additions & 94 deletions deployments/api/src/stitch/api/db/init_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,12 @@
from sqlalchemy.orm import Session

from stitch.api.db.model import (
CCReservoirsSourceModel,
GemSourceModel,
MembershipModel,
RMIManualSourceModel,
ResourceModel,
StitchBase,
UserModel,
WMSourceModel,
)
from stitch.api.entities import (
GemData,
RMIManualData,
User as UserEntity,
WMData,
)

"""
Expand Down Expand Up @@ -273,81 +265,16 @@ def create_dev_user() -> UserModel:
)


def create_seed_sources():
gem_sources = [
GemSourceModel.from_entity(
GemData(name="Permian Basin Field", country="USA", lat=31.8, lon=-102.3)
),
GemSourceModel.from_entity(
GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5)
),
]
for i, src in enumerate(gem_sources, start=1):
src.id = i

wm_sources = [
WMSourceModel.from_entity(
WMData(
field_name="Eagle Ford Shale", field_country="USA", production=125000.5
)
),
WMSourceModel.from_entity(
WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0)
),
]
for i, src in enumerate(wm_sources, start=1):
src.id = i

rmi_sources = [
RMIManualSourceModel.from_entity(
RMIManualData(
name_override="Custom Override Name",
gwp=25.5,
gor=0.45,
country="CAN",
latitude=56.7,
longitude=-111.4,
)
),
]
for i, src in enumerate(rmi_sources, start=1):
src.id = i

# CC Reservoir sources are intentionally omitted from the dev seed profile;
# the CCReservoirsSourceModel table is still created from SQLAlchemy metadata.
cc_sources: list[CCReservoirsSourceModel] = []

return gem_sources, wm_sources, rmi_sources, cc_sources


def create_seed_resources(user: UserEntity) -> list[ResourceModel]:
resources = [
ResourceModel.create(user, name="Multi-Source Asset", country="USA"),
ResourceModel.create(user, name="Single Source Asset", country="GBR"),
ResourceModel.create(user, name="Resource Foo01"),
ResourceModel.create(user, name="Resource Bar01"),
]
for i, res in enumerate(resources, start=1):
res.id = i
return resources


def create_seed_memberships(
user: UserEntity,
resources: list[ResourceModel],
gem_sources: list[GemSourceModel],
wm_sources: list[WMSourceModel],
rmi_sources: list[RMIManualSourceModel],
) -> list[MembershipModel]:
memberships = [
MembershipModel.create(user, resources[0], "gem", gem_sources[0].id),
MembershipModel.create(user, resources[0], "wm", wm_sources[0].id),
MembershipModel.create(user, resources[0], "rmi", rmi_sources[0].id),
MembershipModel.create(user, resources[1], "gem", gem_sources[1].id),
]
for i, mem in enumerate(memberships, start=1):
mem.id = i
return memberships


def reset_sequences(engine, tables: Iterable[str]) -> None:
with engine.begin() as conn:
for t in tables:
Expand Down Expand Up @@ -376,34 +303,15 @@ def seed_dev(engine) -> None:
name=user_model.name,
)

dev_entity = UserEntity(
id=dev_model.id,
sub=dev_model.sub,
email=dev_model.email,
name=dev_model.name,
)

gem_sources, wm_sources, rmi_sources, cc_sources = create_seed_sources()
session.add_all(gem_sources + wm_sources + rmi_sources + cc_sources)

resources = create_seed_resources(user_entity)
resources = create_seed_resources(dev_entity)
session.add_all(resources)

memberships = create_seed_memberships(
user_entity, resources, gem_sources, wm_sources, rmi_sources
)
session.add_all(memberships)

session.commit()

reset_sequences(
engine,
tables=[
"users",
"gem_sources",
"wm_sources",
"rmi_manual_sources",
"resources",
"memberships",
],
Expand Down
10 changes: 0 additions & 10 deletions deployments/api/src/stitch/api/db/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from .common import Base as StitchBase
from .sources import (
GemSourceModel,
RMIManualSourceModel,
CCReservoirsSourceModel,
WMSourceModel,
)
from .resource import MembershipStatus, MembershipModel, ResourceModel
from .user import User as UserModel

__all__ = [
"CCReservoirsSourceModel",
"GemSourceModel",
"MembershipModel",
"MembershipStatus",
"RMIManualSourceModel",
"ResourceModel",
"StitchBase",
"UserModel",
"WMSourceModel",
]
9 changes: 5 additions & 4 deletions deployments/api/src/stitch/api/db/model/mixins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from datetime import datetime
from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin
from pydantic import TypeAdapter
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import DateTime, ForeignKey, String, func
from sqlalchemy.ext.hybrid import hybrid_property

from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
from stitch.api.entities import SourceBase
from .types import StitchJson


Expand All @@ -30,7 +29,7 @@ class UserAuditMixin:
last_updated_by_id: Mapped[int] = mapped_column(ForeignKey("users.id"))


TPayload = TypeVar("TPayload", bound=SourceBase)
TPayload = TypeVar("TPayload", bound=BaseModel)


def _extract_payload_type(cls: type) -> type | None:
Expand Down Expand Up @@ -63,7 +62,9 @@ def payload(self) -> TPayload:

@payload.inplace.setter
def _payload_setter(self, value: TPayload):
self.source = value.source
# Domain-agnostic: if the payload has a `source` attribute, keep it;
# otherwise set a neutral default.
self.source = getattr(value, "source", "unknown")
self._payload_data = value.model_dump(mode="json")

@payload.inplace.expression
Expand Down
35 changes: 5 additions & 30 deletions deployments/api/src/stitch/api/db/model/resource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections import defaultdict
from enum import StrEnum
from sqlalchemy import (
ForeignKey,
Expand All @@ -12,13 +11,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship

from .sources import (
SOURCE_TABLES,
SourceKey,
SourceModel,
SourceModelData,
)
from stitch.api.entities import IdType, User as UserEntity
from stitch.models.types import IdType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think most (if not all) of the imports will actually end up coming from stitch-ogsi.

from stitch.api.entities import User as UserEntity
from .common import Base
from .mixins import TimestampMixin, UserAuditMixin
from .types import PORTABLE_BIGINT
Expand All @@ -43,9 +37,7 @@ class MembershipModel(TimestampMixin, UserAuditMixin, Base):
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
resource_id: Mapped[int] = mapped_column(ForeignKey("resources.id"), nullable=False)
source: Mapped[SourceKey] = mapped_column(
String(10), nullable=False
) # "gem" | "wm"
source: Mapped[str] = mapped_column(String(10), nullable=False) # "gem" | "wm"
Comment on lines -46 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does loosen the restrictions somewhat on the valid values, allowing any string under 10 characters.

If an invalid value was somehow saved to the db, we'd get an error when trying to call model_validate when fetching it from the db.

source_pk: Mapped[int] = mapped_column(PORTABLE_BIGINT, nullable=False)
status: Mapped[MembershipStatus]

Expand All @@ -54,14 +46,14 @@ def create(
cls,
created_by: UserEntity,
resource: "ResourceModel",
source: SourceKey,
source: str,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a place that helps getting editor support and early errors. With the SrcKey (now OGSISrcKey from stitch-ogsi), you'll get an error diagnostic if you try to call Resource.create(source="some_source", ...).

source_pk: IdType,
status: MembershipStatus = MembershipStatus.ACTIVE,
):
model = cls(
resource_id=resource.id,
source=source,
source_pk=str(source_pk),
source_pk=int(source_pk),
status=status,
created_by_id=created_by.id,
last_updated_by_id=created_by.id,
Expand Down Expand Up @@ -98,23 +90,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base):
# and configure the appropriate SQL statement to load the membership objects
memberships: Mapped[list[MembershipModel]] = relationship()

async def get_source_data(self, session: AsyncSession):
pks_by_src: dict[SourceKey, set[int]] = defaultdict(set)
for mem in self.memberships:
if mem.status == MembershipStatus.ACTIVE:
pks_by_src[mem.source].add(mem.source_pk)

results: dict[SourceKey, dict[IdType, SourceModel]] = defaultdict(dict)
for src, pks in pks_by_src.items():
model_cls = SOURCE_TABLES.get(src)
if model_cls is None:
continue
stmt = select(model_cls).where(model_cls.id.in_(pks))
for src_model in await session.scalars(stmt):
results[src][src_model.id] = src_model

return SourceModelData(**results)

async def get_root(self, session: AsyncSession):
root = await session.scalar(self.__class__._root_select(self.id))
if root is None:
Expand Down
Loading