Skip to content

Comments

improve metadata handling with ExposureSeries data model#9

Draft
McHaillet wants to merge 4 commits intoteamtomo:mainfrom
McHaillet:better_metadata
Draft

improve metadata handling with ExposureSeries data model#9
McHaillet wants to merge 4 commits intoteamtomo:mainfrom
McHaillet:better_metadata

Conversation

@McHaillet
Copy link
Contributor

WIP to have a better internal data model in torch-tomogram. Feel free to give suggestions for whoever is interested :)

@codecov
Copy link

codecov bot commented Sep 5, 2025

Codecov Report

❌ Patch coverage is 0% with 36 lines in your changes missing coverage. Please review.
✅ Project coverage is 67.79%. Comparing base (0444d1d) to head (bd46369).

Files with missing lines Patch % Lines
src/torch_tomogram/datamodel.py 0.00% 36 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (0444d1d) and HEAD (bd46369). Click for more details.

HEAD has 2 uploads less than BASE
Flag BASE (0444d1d) HEAD (bd46369)
3 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main       #9       +/-   ##
===========================================
- Coverage   97.56%   67.79%   -29.77%     
===========================================
  Files           2        3        +1     
  Lines          82      118       +36     
===========================================
  Hits           80       80               
- Misses          2       38       +36     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@McHaillet
Copy link
Contributor Author

@alisterburt @mgiammar @davidetorre99 Any of you guys have some ideas for this?

I got stuck on the fact that I want the alignment parameters be optimizable, but this makes the structure a bit wonky. For the ExposureSeries dataclass it feels most intuitive to have a structure like this:

from teamtomo_basemodel import BaseModelTeamTomo as BaseModelTT
from datetime import datetime

class Alignment(BaseModelTT):
    z_rotation_degrees: float = 0.0
    y_rotation_degrees: float = 0.0
    x_rotation_degrees: float = 0.0

    # global image shifts
    y_shift_px: float = 0.0
    x_shift_px: float = 0.0

class Exposure(BaseModelTT):
    alignment: Alignment
    collection_time: datatime

class ExposureSeries(BaseModelTT):
    exposures: List[Exposure]

However if the parameters should be optimizable, they need to be tensors. I found the following workaround, for example for the z_rotation:

class Alignment:
    z_rotation_degrees: Annotated[
            float | torch.Tensor, AfterValidator(tensor_check)
        ] = (
            Field(alias='tilt_axis_angle', default=torch.tensor(0.0))
        )
    ...

The ExposureSeries class could have a function that fetches all the parameters for an optimizer or smh:

class ExposureSeries(BaseModelTT):
    ...
    def get_parameters(self, attribute_name):
        exposure_parameter_list = [getattr(exposure.alignment, attribute) for exposure in self.exposures]
        return exposure_parameter_list

Maybe it just doesn't make sense though that the parameters are tensors in this representation

@davidetorre99
Copy link
Contributor

I don't know whether that would be the ideal approach but what if you split the tt basemodel data model for serialization and tensors for optimization (still within tt basemodel but using ExcludedTensor)? After optimization or whatever you just sync the updated tensors back into the serializable floats. For instance,

from teamtomo_basemodel import BaseModelTeamTomo as BaseModelTT
from teamtomo_basemodel import ExcludedTensor
from datetime import datetime

class Alignment(BaseModelTT):
    z_rotation_degrees: float = Field(default=0.0, description = "z-axis rotation in degrees")
    y_rotation_degrees: float = Field(default=0.0, description = "y-axis rotation in degrees")
    x_rotation_degrees: float = Field(default=0.0, description = "x-axis rotation in degrees")

    # global image shifts
    y_shift_px: float =  Field(default=0.0, description = "y-axis shift in pixles")
    x_shift_px: float = Field(default=0.0, description = "x-axis shift in pixles")

    # Optimization tensors, using the ExcludedTensor from teamtomo_basemodel so that they are not serialized in the class
    _z_rotation_tensor: ExcludedTensor = 0
    _y_rotation_tensor: ExcludedTensor = 0
    _x_rotation_tensor: ExcludedTensor = 0
    _x_shift_tensor: ExcludedTensor = 0
    _y_shift_tensor: ExcludedTensor = 0

    # Convert parameters to tensors
    def enable_optimization(self) -> None:
           self._z_rotation_tensor = torch.tensor(self.z_rotation_degrees)
           self._y_rotation_tensor = torch.tensor(self.y_rotation_degrees)
           self._x_rotation_tensor = torch.tensor(self.x_rotation_degrees)
           self._x_shift_tensor = torch.tensor(self.x_shift_px)
           self._y_shift_tensor = torch.tensor(self.y_shift_px)

    #  Update original values from optimized tensors.
    def sync_from_tensors(self) -> None:
        if self._z_rotation_tensor is not None:
            self.z_rotation_degrees = self._z_rotation_tensor.item()
        if self._y_rotation_tensor is not None:
            self.y_rotation_degrees = self._y_rotation_tensor.item()
        if self._x_rotation_tensor is not None:
            self.x_rotation_degrees = self._x_rotation_tensor.item()
        if self._y_shift_tensor is not None:
            self.y_shift_px = self._y_shift_tensor.item()
        if self._x_shift_tensor is not None:
            self.x_shift_px = self._x_shift_tensor.item()

    # Get all parameter tensors for optimization
    def get_parameter_tensors(self) -> List[torch.Tensor]:
        tensors = []
        for tensor in [self._z_rotation_tensor, self._y_rotation_tensor, 
                      self._x_rotation_tensor, self._y_shift_tensor, self._x_shift_tensor]:
            if tensor is not None:
                tensors.append(tensor)
        return tensors
    
class Exposure(BaseModelTT):
    alignment: Alignment
    collection_time: datetime 

And then for instance the ExposureSeries class can look something like this:

class ExposureSeries(BaseModelTT):
    exposures: List[Exposure]
    
     # Enable "optimization" to then perform whatever you want to do with the tensors
    def enable_optimization(self) -> None:
        for exposure in self.exposures:
            exposure.alignment.enable_optimization()

    # update original parameters after you did whatever you wanted to do
    def sync_from_tensors(self) -> None:
        for exposure in self.exposures:
            exposure.alignment.sync_from_tensors()

    # Get parameters
    def get_all_parameters(self) -> List[torch.Tensor]:
        all_params = []
        for exposure in self.exposures:
            all_params.extend(exposure.alignment.get_parameter_tensors())
        return all_params
    
    # Or if you prefer maybe you can directly parse them as "type" with something like this
    def get_parameters_by_type(self, parameter_name: str) -> List[torch.Tensor]:
        """Get specific parameter type across all exposures."""
        params = []
        for exposure in self.exposures:
            tensor = getattr(exposure.alignment, f"_{parameter_name}_tensor", None)
            if tensor is not None:
                params.append(tensor)
        return params

Again, not sure whether this is the best way, but that should keep the data structure tidy, and you can still optimize the parameters as needed. Just my two cents here, I’m sure @mgiammar and @alisterburt will have better insights!

Copy link

@alisterburt alisterburt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Sorry for not getting to this earlier, just back from vacation and have a big presentation at the end of the week

Great ideas from both of you here, really appreciate the discussion and I learned a few things (field aliases, excluded tensor type on the teamtomo base model) looking at your implementations which is always nice :-)

I don't immediately have super clear thoughts here but will throw out a few ideas that feel important...

  • separation of concerns, what exactly are you modelling? Just alignments? Your whole Tomogram model? The field in a model for alignments shouldn't need to know about things like optimization, if you need that info to describe your model it should be indicated with an extra boolean field like .optimize_shifts
  • if these flags start making the model feel too complex then you can always introduce a layer of nesting, something like this maybe?

class Tomogram(BaseModelTT):
alignments: RigidBodyAlignments
optimize_shifts: bool

We should avoid letting models get toooo deep as it can make constructing them annoying but something like this is fine

  • reiterating the above point, keep a clear separation between TomogramData and the Tomogram class which implements functionality, Tomogram.load_from_json() would then use the pydantic model/serialization logic internally
  • Davide's suggestion of a set of alignments being a List[per-frame alignment] feels like the right way to go I think, it's what pydantic wants you to do - not having tensors/arrays feels annoying at first but that goes away as soon as you implement .as_dataframe on the alignment data structure - this is basically the pattern we fell into with mdocfile/imodmodel and it works well. This model does break down if you have tooooo many instances so be careful the 10k+ realm (atoms in a structure, proteins in a large scale simulation etc)
  • relevant to the above, pydantic has a nice TypeAdapter for these Collection[Model] types which let you work with them nicely, example below
class Membrane2D(pydantic.BaseModel):
    profile_1d: list[float]  # (w, )
    weights_1d: list[float]  # (w, )
    path_control_points: list[tuple[float, float]]  # (b, 2)
    signal_scale_control_points: list[float]  # (b, )
    path_is_closed: bool

from pydantic import TypeAdapter

from torch_subtract_membranes_2d.membrane_model import Membrane2D

# adapter class to help with serialization/deserialization
MembraneList = TypeAdapter(List[Membrane2D])


def save_membranes(
    membranes: List[Membrane2D], path: os.PathLike
) -> None:
    """Save membranes to json."""
    json = MembraneList.dump_json(membranes, indent=2)
    with open(path, "wb") as f:
        f.write(json)
    return


def load_membranes(path: os.PathLike) -> List[Membrane2D]:
    """Load membranes from json."""
    with open(path, "rb") as f:
        json = f.read()
    membranes = MembraneList.validate_json(json)
    return membranes
  • I don't love the mix of numeric types and tensors - one or the other, preference for simple Python numeric types which are converted to tensors when loaded into your target data structure

  • we should bite the bullet and make sure all shifts are parametrized in angstroms now I think, it's important for making alignments transferable. It means we need to keep track of the pixel spacing but that's fine, we will need pixel spacing for the CTF anyway

If any of this doesn't feel good/clear happy to continue the discussion!

@McHaillet
Copy link
Contributor Author

Damn, really appreciate all the great input 😍 Need to parse it a bit and fromulate my thoughts, but at quick glance, this

  • Davide's suggestion of a set of alignments being a List[per-frame alignment] feels like the right way to go I think, it's what pydantic wants you to do - not having tensors/arrays feels annoying at first but that goes away as soon as you implement .as_dataframe on the alignment data structure

feels good. Just starting with a logical pydantic structure, and then implementing methods to covert it as_tensor, or as_dataframe.

@mgiammar
Copy link

I think the separation of metadata and functional tensors is a good approach, and what Alister and Davide are sensible implementations.

From a more philosophical perspective, I think Pydantic constraints are at first a bit tricky to work around, but I've found these force me to write code with good object-oriented design. It's usually taken me two or three passes before I have something I'm really happy with, and some of the guiding questions I've used for design are:

  • What (meta)data am I often loading into Python, and how does it makes sense to split up? (e.g. per-exposure in this case)
  • How am I accessing this data in the functional potion of my code? --> what helper methods do I need to define?
  • Where can I abstract away reused structures, and how do the object hierarchies change
  • What (meta)data am I exporting after computation? Does it make sense to have some *Result class, or should the Python object be updated in-place before export?

Hope this helps a bit, and happy to take a closer look next week when I have some more time.

@alisterburt
Copy link

@McHaillet glad this feels like it can work! As always, you're the heaviest user and will likely have the best intuition around how this feels - I totally agree with @mgiammar that some iteration on what "feels good" is part of the process, you'll know when you've hit that magical API feel - I'm really happy that you are thinking about these issues and communicating in the open about them, this is how we all get better :-)

side point: Tomogram here shouldn't know anything about optimization, what you're modelling seems to be the config for your optimization process which is itself updating the state of the Tomogram - generally keeping "process config" separate from "object state" feels like a good rule of thumb 🙂

@McHaillet
Copy link
Contributor Author

McHaillet commented Sep 12, 2025

Thanks for all the great input, this has helped a huge amount. I needed to take a few steps back and define my goals better. I might have made the scope a bit too wide as well, however I still think its good to consider the metadata representation a bit more widely.

tilt-series dataclass

I want to have a nice way to pass around tilt-series (and frame-series) metadata. Specifically, for this package I need this because it currently only provides reconstruction from a tilt-series to a tomogram based on rigid body alignment parameters, but it should also incorporate CTF and dose weighting (#2). I would like to represent that metadata in a reusable way (both alignment and CTF) as I think it can be something that can be passed around between teamtomo packages.

For this, it feels neat to have a pydantic dataclass called ExposureSeries (and maybe a TiltSeries dataclass that inherits from this), and I agree with you guys' points that it should just be simple Python numeric types and not tensors.

alignment friendly tomogram class

So far, I have been using the Tomogram class here to facilitate alignment optimization, and it feels handy to keep using it that way (hence why I was trying to use tensors). Seeing your points, I guess it would make most sense to just assign my ExposureSeries to the Tomogram class, and then handle conversion to tensors in there.

Building on Davide's tensor initialization and syncing idea, I came up with this (this is non-pydantic, only the ExposureSeries would be pydantic):

class Tomogram:
    def __init__(self, tilt_data: ExposureSeries):
        self.tilt_data = tilt_data
        
        # initialize to tensors
        self._initialize_tensors()

    def _initialize_tensors(self):
        # just showing it with shifts here
        shifts = torch.tensor(self.tilt_data.get_shifts(), dtype=torch.float32)  # [N, 2]
        self._original_shifts_tensor = shifts
        self._shifts_tensor = shifts.clone()

    def add_translations(self, translations: torch.Tensor):
        # this facilitates updating the translations with a tensor that can be used for 
        # optimization -> translations will have requires_grad=True
        # the same thing could be done for rotations, or defocus ...
        self._shifts_tensor = self._original_shifts_tensor + translations
    
    def reconstruct_tomogram(self, shape, pixel_size):
        # reconstruct a tomogram at the provided pixel size from the current state of tensors 

    def sync_from_tensors(self,):
        # sync current tensor state to an ExposureSeries 

Let me know how you guys feel about this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants