-
Notifications
You must be signed in to change notification settings - Fork 0
Add dependency on stitch models for API #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2fa6e45
2d08ac8
5ef1c25
e68fcd7
6ec0a17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,21 +1,11 @@ | ||
| from .common import Base as StitchBase | ||
| from .sources import ( | ||
| GemSourceModel, | ||
| RMIManualSourceModel, | ||
| CCReservoirsSourceModel, | ||
| WMSourceModel, | ||
| ) | ||
| from .resource import MembershipStatus, MembershipModel, ResourceModel | ||
| from .user import User as UserModel | ||
|
|
||
| __all__ = [ | ||
| "CCReservoirsSourceModel", | ||
| "GemSourceModel", | ||
| "MembershipModel", | ||
| "MembershipStatus", | ||
| "RMIManualSourceModel", | ||
| "ResourceModel", | ||
| "StitchBase", | ||
| "UserModel", | ||
| "WMSourceModel", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,3 @@ | ||
| from collections import defaultdict | ||
| from enum import StrEnum | ||
| from sqlalchemy import ( | ||
| ForeignKey, | ||
|
|
@@ -12,13 +11,8 @@ | |
| from sqlalchemy.ext.asyncio import AsyncSession | ||
| from sqlalchemy.orm import Mapped, mapped_column, relationship | ||
|
|
||
| from .sources import ( | ||
| SOURCE_TABLES, | ||
| SourceKey, | ||
| SourceModel, | ||
| SourceModelData, | ||
| ) | ||
| from stitch.api.entities import IdType, User as UserEntity | ||
| from stitch.models.types import IdType | ||
| from stitch.api.entities import User as UserEntity | ||
| from .common import Base | ||
| from .mixins import TimestampMixin, UserAuditMixin | ||
| from .types import PORTABLE_BIGINT | ||
|
|
@@ -43,9 +37,7 @@ class MembershipModel(TimestampMixin, UserAuditMixin, Base): | |
| ) | ||
| id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) | ||
| resource_id: Mapped[int] = mapped_column(ForeignKey("resources.id"), nullable=False) | ||
| source: Mapped[SourceKey] = mapped_column( | ||
| String(10), nullable=False | ||
| ) # "gem" | "wm" | ||
| source: Mapped[str] = mapped_column(String(10), nullable=False) # "gem" | "wm" | ||
|
Comment on lines
-46
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does loosen the restrictions somewhat on the valid values, allowing any string under 10 characters. If an invalid value was somehow saved to the db, we'd get an error when trying to call |
||
| source_pk: Mapped[int] = mapped_column(PORTABLE_BIGINT, nullable=False) | ||
| status: Mapped[MembershipStatus] | ||
|
|
||
|
|
@@ -54,14 +46,14 @@ def create( | |
| cls, | ||
| created_by: UserEntity, | ||
| resource: "ResourceModel", | ||
| source: SourceKey, | ||
| source: str, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was a place that helps getting editor support and early errors. With the |
||
| source_pk: IdType, | ||
| status: MembershipStatus = MembershipStatus.ACTIVE, | ||
| ): | ||
| model = cls( | ||
| resource_id=resource.id, | ||
| source=source, | ||
| source_pk=str(source_pk), | ||
| source_pk=int(source_pk), | ||
| status=status, | ||
| created_by_id=created_by.id, | ||
| last_updated_by_id=created_by.id, | ||
|
|
@@ -98,23 +90,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base): | |
| # and configure the appropriate SQL statement to load the membership objects | ||
| memberships: Mapped[list[MembershipModel]] = relationship() | ||
|
|
||
| async def get_source_data(self, session: AsyncSession): | ||
| pks_by_src: dict[SourceKey, set[int]] = defaultdict(set) | ||
| for mem in self.memberships: | ||
| if mem.status == MembershipStatus.ACTIVE: | ||
| pks_by_src[mem.source].add(mem.source_pk) | ||
|
|
||
| results: dict[SourceKey, dict[IdType, SourceModel]] = defaultdict(dict) | ||
| for src, pks in pks_by_src.items(): | ||
| model_cls = SOURCE_TABLES.get(src) | ||
| if model_cls is None: | ||
| continue | ||
| stmt = select(model_cls).where(model_cls.id.in_(pks)) | ||
| for src_model in await session.scalars(stmt): | ||
| results[src][src_model.id] = src_model | ||
|
|
||
| return SourceModelData(**results) | ||
|
|
||
| async def get_root(self, session: AsyncSession): | ||
| root = await session.scalar(self.__class__._root_select(self.id)) | ||
| if root is None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think most (if not all) of the imports will actually end up coming from
stitch-ogsi.