improve metadata handling with ExposureSeries data model #9

McHaillet wants to merge 4 commits into teamtomo:main
Conversation
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```text
@@            Coverage Diff             @@
##             main       #9       +/-  ##
==========================================
- Coverage   97.56%   67.79%   -29.77%
==========================================
  Files           2        3        +1
  Lines          82      118       +36
==========================================
  Hits           80       80
- Misses          2       38       +36
```
@alisterburt @mgiammar @davidetorre99 Any of you guys have some ideas for this? I got stuck on the fact that I want the alignment parameters to be optimizable, but this makes the structure a bit wonky. For the ExposureSeries dataclass it feels most intuitive to have a structure like this:

```python
from datetime import datetime
from typing import List

from teamtomo_basemodel import BaseModelTeamTomo as BaseModelTT

class Alignment(BaseModelTT):
    z_rotation_degrees: float = 0.0
    y_rotation_degrees: float = 0.0
    x_rotation_degrees: float = 0.0
    # global image shifts
    y_shift_px: float = 0.0
    x_shift_px: float = 0.0

class Exposure(BaseModelTT):
    alignment: Alignment
    collection_time: datetime

class ExposureSeries(BaseModelTT):
    exposures: List[Exposure]
```

However, if the parameters should be optimizable, they need to be tensors. I found the following workaround, for example for the z rotation:

```python
from typing import Annotated

import torch
from pydantic import AfterValidator, Field

class Alignment(BaseModelTT):
    z_rotation_degrees: Annotated[
        float | torch.Tensor, AfterValidator(tensor_check)
    ] = Field(alias='tilt_axis_angle', default=torch.tensor(0.0))
    ...
```

The ExposureSeries class could have a function that fetches all the parameters for an optimizer or something:

```python
class ExposureSeries(BaseModelTT):
    ...
    def get_parameters(self, attribute_name):
        return [
            getattr(exposure.alignment, attribute_name)
            for exposure in self.exposures
        ]
```

Maybe it just doesn't make sense though that the parameters are tensors in this representation.
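For context on what "fetching parameters for an optimizer" would mean in practice, here is a minimal, self-contained sketch (not from the PR; the plain list of leaf tensors stands in for what a `get_parameters`-style method would return):

```python
import torch

# Stand-ins for one optimizable parameter per exposure, as a method like
# get_parameters("z_rotation_degrees") might return them.
parameters = [torch.tensor(0.0, requires_grad=True) for _ in range(3)]
optimizer = torch.optim.Adam(parameters, lr=0.1)

for _ in range(200):
    optimizer.zero_grad()
    # toy loss pulling every parameter towards 1.0
    loss = sum((p - 1.0) ** 2 for p in parameters)
    loss.backward()
    optimizer.step()
```

The constraint this illustrates: whatever the model stores, the optimizer needs leaf tensors with `requires_grad=True`, which is why plain floats in the pydantic model don't suffice on their own.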
I don't know whether that would be the ideal approach, but what if you split the tt basemodel data model for serialization and tensors for optimization (still within tt basemodel but using ExcludedTensor)? After optimization or whatever you just sync the updated tensors back into the serializable floats. For instance:

```python
from datetime import datetime
from typing import List

import torch
from pydantic import Field
from teamtomo_basemodel import BaseModelTeamTomo as BaseModelTT
from teamtomo_basemodel import ExcludedTensor

class Alignment(BaseModelTT):
    z_rotation_degrees: float = Field(default=0.0, description="z-axis rotation in degrees")
    y_rotation_degrees: float = Field(default=0.0, description="y-axis rotation in degrees")
    x_rotation_degrees: float = Field(default=0.0, description="x-axis rotation in degrees")
    # global image shifts
    y_shift_px: float = Field(default=0.0, description="y-axis shift in pixels")
    x_shift_px: float = Field(default=0.0, description="x-axis shift in pixels")

    # optimization tensors, using the ExcludedTensor from teamtomo_basemodel
    # so that they are not serialized with the class
    _z_rotation_tensor: ExcludedTensor = None
    _y_rotation_tensor: ExcludedTensor = None
    _x_rotation_tensor: ExcludedTensor = None
    _x_shift_tensor: ExcludedTensor = None
    _y_shift_tensor: ExcludedTensor = None

    def enable_optimization(self) -> None:
        """Convert parameters to tensors."""
        self._z_rotation_tensor = torch.tensor(self.z_rotation_degrees)
        self._y_rotation_tensor = torch.tensor(self.y_rotation_degrees)
        self._x_rotation_tensor = torch.tensor(self.x_rotation_degrees)
        self._x_shift_tensor = torch.tensor(self.x_shift_px)
        self._y_shift_tensor = torch.tensor(self.y_shift_px)

    def sync_from_tensors(self) -> None:
        """Update the original values from the optimized tensors."""
        if self._z_rotation_tensor is not None:
            self.z_rotation_degrees = self._z_rotation_tensor.item()
        if self._y_rotation_tensor is not None:
            self.y_rotation_degrees = self._y_rotation_tensor.item()
        if self._x_rotation_tensor is not None:
            self.x_rotation_degrees = self._x_rotation_tensor.item()
        if self._y_shift_tensor is not None:
            self.y_shift_px = self._y_shift_tensor.item()
        if self._x_shift_tensor is not None:
            self.x_shift_px = self._x_shift_tensor.item()

    def get_parameter_tensors(self) -> List[torch.Tensor]:
        """Get all parameter tensors for optimization."""
        candidates = (
            self._z_rotation_tensor,
            self._y_rotation_tensor,
            self._x_rotation_tensor,
            self._y_shift_tensor,
            self._x_shift_tensor,
        )
        return [tensor for tensor in candidates if tensor is not None]

class Exposure(BaseModelTT):
    alignment: Alignment
    collection_time: datetime
```

And then for instance the ExposureSeries class can look something like this:

```python
class ExposureSeries(BaseModelTT):
    exposures: List[Exposure]

    def enable_optimization(self) -> None:
        """Enable "optimization" to then do whatever you want with the tensors."""
        for exposure in self.exposures:
            exposure.alignment.enable_optimization()

    def sync_from_tensors(self) -> None:
        """Update the original parameters after you did whatever you wanted to do."""
        for exposure in self.exposures:
            exposure.alignment.sync_from_tensors()

    def get_all_parameters(self) -> List[torch.Tensor]:
        """Get the parameter tensors of all exposures."""
        all_params = []
        for exposure in self.exposures:
            all_params.extend(exposure.alignment.get_parameter_tensors())
        return all_params

    # or, if you prefer, you can fetch them by "type" with something like this
    def get_parameters_by_type(self, parameter_name: str) -> List[torch.Tensor]:
        """Get a specific parameter type across all exposures."""
        params = []
        for exposure in self.exposures:
            tensor = getattr(exposure.alignment, f"_{parameter_name}_tensor", None)
            if tensor is not None:
                params.append(tensor)
        return params
```

Again, not sure whether this is the best way, but that should keep the data structure tidy, and you can still optimize the parameters as needed. Just my two cents here, I'm sure @mgiammar and @alisterburt will have better insights!
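As a quick illustration of the enable/sync round trip described above, here is a hedged, minimal stand-in (plain Python instead of the pydantic model, and a single shift field) showing the intended flow:

```python
import torch

# Minimal stand-in for the Alignment model above: one serializable float
# plus one excluded tensor used only during optimization.
class Alignment:
    def __init__(self):
        self.x_shift_px = 0.0
        self._x_shift_tensor = None

    def enable_optimization(self):
        # convert the float to an optimizable leaf tensor
        self._x_shift_tensor = torch.tensor(self.x_shift_px, requires_grad=True)

    def sync_from_tensors(self):
        # write the optimized value back into the serializable float
        if self._x_shift_tensor is not None:
            self.x_shift_px = self._x_shift_tensor.item()

alignment = Alignment()
alignment.enable_optimization()
optimizer = torch.optim.SGD([alignment._x_shift_tensor], lr=0.5)
for _ in range(50):
    optimizer.zero_grad()
    # toy loss pulling the shift towards 2.0
    loss = (alignment._x_shift_tensor - 2.0) ** 2
    loss.backward()
    optimizer.step()
alignment.sync_from_tensors()
# alignment.x_shift_px is now a plain float close to 2.0, ready to serialize
```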
alisterburt
left a comment
Hey! Sorry for not getting to this earlier, just back from vacation and I have a big presentation at the end of the week.
Great ideas from both of you here, really appreciate the discussion and I learned a few things (field aliases, the excluded tensor type on the teamtomo base model) looking at your implementations, which is always nice :-)
I don't immediately have super clear thoughts here but will throw out a few ideas that feel important...
- separation of concerns: what exactly are you modelling? Just alignments? Your whole Tomogram model? The field in a model for alignments shouldn't need to know about things like optimization; if you need that info to describe your model it should be indicated with an extra boolean field like `optimize_shifts`
- if these flags start making the model feel too complex then you can always introduce a layer of nesting, something like this maybe?

```python
class Tomogram(BaseModelTT):
    alignments: RigidBodyAlignments
    optimize_shifts: bool
```

  We should avoid letting models get too deep as it can make constructing them annoying, but something like this is fine
- reiterating the above point, keep a clear separation between TomogramData and the Tomogram class which implements functionality; Tomogram.load_from_json() would then use the pydantic model/serialization logic internally
- Davide's suggestion of a set of alignments being a List[per-frame alignment] feels like the right way to go I think, it's what pydantic wants you to do. Not having tensors/arrays feels annoying at first, but that goes away as soon as you implement .as_dataframe on the alignment data structure; this is basically the pattern we fell into with mdocfile/imodmodel and it works well. This model does break down if you have too many instances, so be careful in the 10k+ realm (atoms in a structure, proteins in a large-scale simulation etc.)
- relevant to the above, pydantic has a nice TypeAdapter for these Collection[Model] types which lets you work with them nicely, example below:
```python
import pydantic

class Membrane2D(pydantic.BaseModel):
    profile_1d: list[float]  # (w, )
    weights_1d: list[float]  # (w, )
    path_control_points: list[tuple[float, float]]  # (b, 2)
    signal_scale_control_points: list[float]  # (b, )
    path_is_closed: bool
```

```python
import os
from typing import List

from pydantic import TypeAdapter
from torch_subtract_membranes_2d.membrane_model import Membrane2D

# adapter class to help with serialization/deserialization
MembraneList = TypeAdapter(List[Membrane2D])

def save_membranes(
    membranes: List[Membrane2D], path: os.PathLike
) -> None:
    """Save membranes to json."""
    json = MembraneList.dump_json(membranes, indent=2)
    with open(path, "wb") as f:
        f.write(json)

def load_membranes(path: os.PathLike) -> List[Membrane2D]:
    """Load membranes from json."""
    with open(path, "rb") as f:
        json = f.read()
    membranes = MembraneList.validate_json(json)
    return membranes
```
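The `.as_dataframe` idea mentioned in the bullets above could be sketched roughly like this (a hypothetical illustration, not from the PR; a real implementation would hand the column dict to a pandas DataFrame constructor, and the field names are assumptions):

```python
from dataclasses import dataclass
from typing import Dict, List

# Stand-in for one per-frame alignment entry in the List[per-frame alignment]
# pattern; field names mirror the Alignment model discussed above.
@dataclass
class FrameAlignment:
    z_rotation_degrees: float
    y_shift_px: float
    x_shift_px: float

def as_columns(alignments: List[FrameAlignment]) -> Dict[str, List[float]]:
    """Flatten a list of per-frame alignments into a dict of columns,
    which is exactly what pandas.DataFrame(...) would consume."""
    return {
        "z_rotation_degrees": [a.z_rotation_degrees for a in alignments],
        "y_shift_px": [a.y_shift_px for a in alignments],
        "x_shift_px": [a.x_shift_px for a in alignments],
    }
```

The point being that tabular (or array) views are cheap to derive from a list of small models, so the models themselves never need to hold tensors.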
- I don't love the mix of numeric types and tensors: one or the other, with a preference for simple Python numeric types which are converted to tensors when loaded into your target data structure
- we should bite the bullet and make sure all shifts are parametrized in angstroms now, I think; it's important for making alignments transferable. It means we need to keep track of the pixel spacing, but that's fine, we will need pixel spacing for the CTF anyway
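The angstrom parametrization just means carrying the pixel spacing alongside the alignments and converting at the boundary; a trivial sketch (helper names are made up for illustration):

```python
def shift_px_to_angstrom(shift_px: float, pixel_spacing_angstrom: float) -> float:
    """Convert an image shift in pixels to angstroms."""
    return shift_px * pixel_spacing_angstrom

def shift_angstrom_to_px(shift_angstrom: float, pixel_spacing_angstrom: float) -> float:
    """Convert an image shift in angstroms back to pixels."""
    return shift_angstrom / pixel_spacing_angstrom
```

Shifts stored in angstroms stay valid when images are rescaled to a different pixel size, which is the transferability argument above.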
If any of this doesn't feel good/clear, happy to continue the discussion!
Damn, really appreciate all the great input 😍 Need to parse it a bit and formulate my thoughts, but at a quick glance this feels good. Just starting with a logical pydantic structure, and then implementing methods to convert it.
I think the separation of metadata and functional tensors is a good approach, and what Alister and Davide suggested are sensible implementations. From a more philosophical perspective, I think Pydantic constraints are at first a bit tricky to work around, but I've found these force me to write code with good object-oriented design. It's usually taken me two or three passes before I have something I'm really happy with, and a few guiding questions have helped me with the design along the way.
Hope this helps a bit, and happy to take a closer look next week when I have some more time.
@McHaillet glad this feels like it can work! As always, you're the heaviest user and will likely have the best intuition around how this feels. I totally agree with @mgiammar that some iteration on what "feels good" is part of the process; you'll know when you've hit that magical API feel. I'm really happy that you are thinking about these issues and communicating in the open about them, this is how we all get better :-)
Side point: Tomogram here shouldn't know anything about optimization; what you're modelling seems to be the config for your optimization process, which is itself updating the state of the Tomogram. Generally, keeping "process config" separate from "object state" feels like a good rule of thumb 🙂
Thanks for all the great input, this has helped a huge amount. I needed to take a few steps back and define my goals better. I might have made the scope a bit too wide as well, but I still think it's good to consider the metadata representation a bit more widely.

**tilt-series dataclass**

I want to have a nice way to pass around tilt-series (and frame-series) metadata. Specifically, for this package I need this because it currently only provides reconstruction from a tilt-series to a tomogram based on rigid body alignment parameters, but it should also incorporate CTF and dose weighting (#2). I would like to represent that metadata in a reusable way (both alignment and CTF) as I think it can be something that can be passed around between teamtomo packages. For this, it feels neat to have a dedicated pydantic dataclass.

**alignment friendly tomogram class**

So far, I have been using the Tomogram class here to facilitate alignment optimization, and it feels handy to keep using it that way (hence why I was trying to use tensors). Seeing your points, I guess it would make most sense to just assign my ExposureSeries to the Tomogram class, and then handle conversion to tensors in there. Building on Davide's tensor initialization and syncing idea, I came up with this (this is non-pydantic; only the ExposureSeries would be pydantic):

```python
import torch

class Tomogram:
    def __init__(self, tilt_data: ExposureSeries):
        self.tilt_data = tilt_data
        # initialize tensors from the metadata
        self._initialize_tensors()

    def _initialize_tensors(self):
        # just showing it with shifts here
        shifts = torch.tensor(self.tilt_data.get_shifts(), dtype=torch.float32)  # (N, 2)
        self._original_shifts_tensor = shifts
        self._shifts_tensor = shifts.clone()

    def add_translations(self, translations: torch.Tensor):
        # this facilitates updating the translations with a tensor that can be
        # used for optimization -> translations will have requires_grad=True;
        # the same thing could be done for rotations, or defocus ...
        self._shifts_tensor = self._original_shifts_tensor + translations

    def reconstruct_tomogram(self, shape, pixel_size):
        # reconstruct a tomogram at the provided pixel size
        # from the current state of the tensors
        ...

    def sync_from_tensors(self):
        # sync the current tensor state back to an ExposureSeries
        ...
```

Let me know how you guys feel about this
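To sanity-check that the `add_translations` pattern composes with autograd, here is a toy loop under assumed shapes (the quadratic loss is a placeholder for a real projection-matching loss, and the names mirror the sketch above rather than any real API):

```python
import torch

# Stand-ins for the Tomogram's internal state: N=5 exposures, (y, x) shifts.
original_shifts = torch.zeros(5, 2)   # like _original_shifts_tensor
translations = torch.zeros(5, 2, requires_grad=True)
target = torch.ones(5, 2)             # pretend "true" shifts

optimizer = torch.optim.SGD([translations], lr=0.5)
for _ in range(50):
    optimizer.zero_grad()
    # mirrors add_translations(): current shifts = original + optimizable delta
    shifts = original_shifts + translations
    # placeholder loss; a real one would compare reprojections to tilt images
    loss = ((shifts - target) ** 2).sum()
    loss.backward()
    optimizer.step()
```

Because `add_translations` rebuilds `_shifts_tensor` from the originals each time, gradients flow only through `translations`, and `sync_from_tensors` can later write the converged values back into the ExposureSeries.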
WIP to have a better internal data model in torch-tomogram. Feel free to give suggestions, whoever is interested :)