Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions pyispyb/app/extensions/database/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Optional, Any
from fastapi import HTTPException

import sqlalchemy
from sqlalchemy.orm import joinedload
Expand All @@ -9,6 +10,7 @@
from pyispyb.app.globals import g
from pyispyb.app.extensions.database.middleware import db


logger = logging.getLogger(__name__)

_session = sqlalchemy.func.concat(
Expand Down Expand Up @@ -52,6 +54,42 @@ def get_options() -> Options:
return app.db_options


def authorize_for_proposal(proposalId: int) -> True:
query = db.session.query(models.Proposal).filter(
models.Proposal.proposalId == proposalId
)
query = with_authorization_proposal(query)
res = query.count()
if res == 0:
raise HTTPException(
status_code=403, detail="User is not authorized for proposal"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here i would return 404 Proposal not found and log that the user was attempting to access a proposal for which they didn't have access

)


def with_authorization_proposal(
query: "sqlalchemy.orm.Query[Any]",
includeArchived: bool = False,
):
return with_authorization(
query=query,
includeArchived=includeArchived,
proposalColumn=None,
joinBLSession=True,
)


def with_authorization_session(
query: "sqlalchemy.orm.Query[Any]",
includeArchived: bool = False,
):
return with_authorization(
query=query,
includeArchived=includeArchived,
proposalColumn=models.BLSession.proposalId,
joinBLSession=False,
)


def with_authorization(
query: "sqlalchemy.orm.Query[Any]",
includeArchived: bool = False,
Expand Down
8 changes: 6 additions & 2 deletions pyispyb/app/extensions/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,15 @@ def with_metadata(
return parsed


def update_model(model: any, values: dict[str, any]):
def update_model(model: any, values: dict[str, any], nested=True):
"""Update a model with new values including nested models"""
for key, value in values.items():
if isinstance(value, dict):
update_model(getattr(model, key), value)
if nested:
update_model(getattr(model, key), value)
elif isinstance(value, list):
if nested:
raise NotImplementedError("Need to implement nested list update")
else:
if isinstance(value, enum.Enum):
value = value.value
Expand Down
251 changes: 239 additions & 12 deletions pyispyb/core/modules/samples.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import enum
from typing import Optional
from fastapi import HTTPException

from sqlalchemy.orm import contains_eager, aliased, joinedload
from sqlalchemy.sql.expression import func, distinct, and_, literal_column
from ispyb import models

from ...config import settings
from ...app.extensions.database.definitions import with_authorization
from ...app.extensions.database.definitions import (
authorize_for_proposal,
with_authorization,
with_authorization_proposal,
)
from ...app.extensions.database.middleware import db
from ...app.extensions.database.utils import (
Paged,
Expand Down Expand Up @@ -117,7 +122,7 @@ def get_samples(
models.AutoProcScalingHasInt.autoProcScalingId
== models.AutoProcScalingStatistics.autoProcScalingId,
)
.join(
.outerjoin(
models.Container,
models.BLSample.containerId == models.Container.containerId,
)
Expand All @@ -126,7 +131,7 @@ def get_samples(
models.Container.code,
)
)
.join(models.Dewar, models.Container.dewarId == models.Dewar.dewarId)
.outerjoin(models.Dewar, models.Container.dewarId == models.Dewar.dewarId)
.options(
contains_eager(
models.BLSample.Container,
Expand All @@ -135,15 +140,17 @@ def get_samples(
models.Dewar.code,
)
)
.join(models.Shipping, models.Dewar.shippingId == models.Shipping.shippingId)
.outerjoin(
models.Shipping, models.Dewar.shippingId == models.Shipping.shippingId
)
.options(
contains_eager(
models.BLSample.Container, models.Container.Dewar, models.Dewar.Shipping
).load_only(
models.Shipping.shippingName,
)
)
.join(models.Proposal, models.Proposal.proposalId == models.Shipping.proposalId)
.join(models.Proposal, models.Proposal.proposalId == models.Protein.proposalId)
.group_by(models.BLSample.blSampleId)
)

Expand Down Expand Up @@ -238,24 +245,204 @@ def get_samples(
return Paged(total=total, results=results, skip=skip, limit=limit)


def build_compositions(
composition_model,
compositions: list[schema.CompositionCreate | schema.Composition | None] | None,
proposal: models.Proposal,
):
res = []
if compositions is None:
return res
for c in compositions:
if c is not None:
component: models.Component = None
# Try to find component in DB
if isinstance(c.Component, schema.Component):
component = with_authorization_proposal(
db.session.query(models.Component)
.filter(models.Component.componentId == c.Component.componentId)
.join(models.Proposal)
.filter(models.Proposal.proposalId == proposal.proposalId)
).first()
if component is None:
raise HTTPException(
status_code=422,
detail=f"Could not find component with id {c.Component.componentId}",
)
# If c.Component is ComponentCreate, try to find same component to avoid duplicate
else:
# Try to find component type in DB
component_type: models.ComponentType = (
db.session.query(models.ComponentType)
.filter(models.ComponentType.name == c.Component.ComponentType.name)
.first()
)
# If component_type found, try to find component
if component_type is not None:
component = (
db.session.query(models.Component)
.filter(models.Component.name == c.Component.name)
.filter(models.Component.ComponentType == component_type)
.filter(models.Component.Proposal == proposal)
.first()
)
# If no component type found, create
else:
component_type = models.ComponentType(
**c.Component.ComponentType.dict()
)

# If no component found, create
if component is None:
component = models.Component(
**{
**c.Component.dict(),
"ComponentType": component_type,
"Proposal": proposal,
"componentId": None,
}
)

# find concentration type in DB
concentration_type = None
if c.ConcentrationType is not None:
concentration_type = (
db.session.query(models.ConcentrationType)
.filter(
models.ConcentrationType.concentrationTypeId
== c.ConcentrationType.concentrationTypeId
)
.first()
)
if concentration_type is None:
raise HTTPException(
status_code=422,
detail=f"Could not find concentration_type with id {c.ConcentrationType.concentrationTypeId}",
)

# create final composition object
composition = composition_model(
**{
**c.dict(),
"Component": component,
"ConcentrationType": concentration_type,
}
)
res.append(composition)
return res


def build_crystal(sample: schema.SampleCreate | schema.SampleUpdate) -> models.Crystal:
crystal: models.Crystal = None
if isinstance(sample.Crystal, schema.SampleCrystalUpdate):
crystal = with_authorization_proposal(
db.session.query(models.Crystal)
.filter(models.Crystal.crystalId == sample.Crystal.crystalId)
.join(models.Protein)
.join(models.Proposal)
).first()
if crystal is None:
raise HTTPException(
status_code=422,
detail=f"Could not find Crystal with id {sample.Crystal.crystalId}",
)
update_model(crystal, sample.Crystal.dict(exclude_unset=True), nested=False)
else:
# Create new crystal
protein = with_authorization_proposal(
db.session.query(models.Protein)
.filter(models.Protein.proteinId == sample.Crystal.Protein.proteinId)
.join(models.Proposal)
).first()
if protein is None:
raise HTTPException(
status_code=422,
detail=f"Could not find protein with id {sample.Crystal.Protein.proteinId}",
)
crystal = models.Crystal(
**{**sample.Crystal.dict(), "Protein": protein, "crystal_compositions": []}
)

proposal = crystal.Protein.Proposal

crystal.crystal_compositions = build_compositions(
models.CrystalComposition, sample.Crystal.crystal_compositions, proposal
)

return crystal


def create_sample(sample: schema.SampleCreate) -> models.BLSample:
sample_dict = sample.dict()
sample = models.BLSample(**sample_dict)
db.session.add(sample)
crystal = build_crystal(sample)

proposal = crystal.Protein.Proposal

authorize_for_proposal(proposal.proposalId)

sample_compositions = build_compositions(
models.SampleComposition, sample.sample_compositions, proposal
)

new_sample = models.BLSample(
**{
**sample_dict,
"Crystal": crystal,
"sample_compositions": sample_compositions,
}
)

db.session.add(new_sample)
db.session.commit()

new_sample = get_samples(sampleId=sample.sampleId, skip=0, limit=1)
new_sample = get_samples(blSampleId=new_sample.blSampleId, skip=0, limit=1)
return new_sample.first


def update_sample(sampleId: int, sample: schema.SampleCreate) -> models.BLSample:
def update_sample(blSampleId: int, sample: schema.SampleUpdate) -> models.BLSample:
sample_dict = sample.dict(exclude_unset=True)
new_sample = get_samples(sampleId=sampleId, skip=0, limit=1).first
old_sample = get_samples(blSampleId=blSampleId, skip=0, limit=1).first
update_model(old_sample, sample_dict, nested=False)

crystal = build_crystal(sample)
old_sample.Crystal = crystal

proposal = crystal.Protein.Proposal
authorize_for_proposal(proposal.proposalId)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think you should need this.

old_sample = get_samples(blSampleId=blSampleId, skip=0, limit=1).first

will check the authorization of the current user vs the sample and return an empty list if they dont have access, thus throwing IndexError which the route function will catch


old_sample.sample_compositions = build_compositions(
models.SampleComposition, sample.sample_compositions, proposal
)

update_model(new_sample, sample_dict)
db.session.commit()

return get_samples(sampleId=sampleId, skip=0, limit=1).first
return get_samples(blSampleId=sample.blSampleId, skip=0, limit=1).first


def delete_sample(
blSampleId: int,
) -> None:
sample = get_samples(blSampleId=blSampleId, skip=0, limit=1).first
if sample._metadata["datacollections"] > 0:
raise HTTPException(
status_code=409,
detail="Sample cannot be deleted because it is associated with data collections",
)

if sample._metadata["subsamples"] > 0:
raise HTTPException(
status_code=409,
detail="Sample cannot be deleted because it is associated with sub samples",
)

if sample._metadata["autoIntegrations"] > 0:
raise HTTPException(
status_code=409,
detail="Sample cannot be deleted because it is associated autoIntegrations",
)

db.session.delete(sample)
db.session.commit()


SUBSAMPLE_ORDER_BY_MAP = {
Expand Down Expand Up @@ -420,3 +607,43 @@ def get_sample_images(
results = with_metadata(query.all(), list(metadata.keys()))

return Paged(total=total, results=results, skip=skip, limit=limit)


def get_components(
skip: int,
limit: int,
proposal: Optional[str] = None,
) -> Paged[models.Component]:

query = db.session.query(models.Component).join(models.Proposal)

if proposal:
proposal_row = (
db.session.query(models.Proposal)
.filter(models.Proposal.proposal == proposal)
.first()
)
if proposal_row:
query = query.filter(models.Proposal.proposalId == proposal_row.proposalId)

query = with_authorization(query)

total = query.count()
query = page(query, skip=skip, limit=limit)
results = query.all()

return Paged(total=total, results=results, skip=skip, limit=limit)


def get_component_types() -> list[models.ComponentType]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe paginate these responses, if only for consistency?


query = db.session.query(models.ComponentType)
results = query.all()

return results


def get_concentration_types() -> list[models.ConcentrationType]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, maybe paginate these responses, if only for consistency?

query = db.session.query(models.ConcentrationType)
results = query.all()
return results
Loading