20 changes: 10 additions & 10 deletions Bioimage-io-scripts/UNet3DKinetochores.model.yaml
@@ -25,10 +25,10 @@ timestamp: 2019-12-11T12:22:32Z # ISO 8601
inputs:
- name: raw
description: raw input
-axes: bczyx # letters of axes in btczyx
+axes: czyx # letters of axes in btczyx
data_type: float32
data_range: [-inf, inf]
-shape: [1, 1, 48, 128, 128]
+shape: [1, 48, 128, 128]
preprocessing: # list of preprocessing steps
- name: zero_mean_unit_variance # name of preprocessing step
kwargs:
@@ -45,14 +45,14 @@ inputs:
outputs:
- name: probability
description: probability in [0,1]
-axes: bczyx
+axes: czyx
data_type: float32
data_range: [-inf, inf]
-halo: [0, 0, 32, 48, 48]
+halo: [0, 32, 48, 48]
Member:
is really so much of the returned output affected by edge artefacts?

Collaborator (author):
This is based on the settings in Pytorch-3DUNet, which worked well for the kinetochores use case (training as well as inference). But do you suggest going for smaller values? (I can try that)

Member:
could this be the halo within the network? You have your slicer, etc. to go over the whole volume. Just make sure that this is actually specifying the final output and not an intermediate step within your algorithm. If that's the case leave it as is and let's get this working before we start tweaking things.

@kiryteo (Collaborator, author), Dec 18, 2020:

Yes, I checked for this and it is part of the predictor config and the routine.

Member:
my understanding was that, as this slicer is only used within your model, it has no direct influence on the overall input and output of the whole bioimage.io model. Let's take a closer look at this when we have a more or less running example.
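A note on the halo semantics discussed in this thread: in the bioimage.io spec, the halo tells the consumer how many voxels at each output border are unreliable and may be cropped away. A minimal sketch of that cropping (the `crop_halo` helper is illustrative, not a pybio function); note that with the declared `czyx` shape `[1, 48, 128, 128]` and halo `[0, 32, 48, 48]`, cropping the full halo leaves no valid voxels along z, which may be what the question above is probing:

```python
import numpy as np

def crop_halo(output, halo):
    # Crop halo[i] voxels from both ends of every axis; a halo of 0 keeps
    # the axis untouched. Illustrative helper, not part of pybio/bioimage.io.
    slices = tuple(slice(h, s - h if h else None)
                   for h, s in zip(halo, output.shape))
    return output[slices]

pred = np.zeros((1, 48, 128, 128), dtype=np.float32)  # czyx, as in the spec
valid = crop_halo(pred, [0, 32, 48, 48])
print(valid.shape)  # (1, 0, 32, 32) -- z is fully consumed by the halo
```

If the halo is meant per tile rather than for the whole volume, overlapping tiles would be needed to cover z at all.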

shape:
reference_input: raw
-scale: [1, 1, 1, 1, 1]
-offset: [0, 0, 0, 0, 0]
+scale: [1, 1, 1, 1]
+offset: [0, 0, 0, 0]

language: python
framework: pytorch
@@ -63,11 +63,11 @@ dependencies: conda:../environment.yaml
test_inputs: [test_input.npy]
test_outputs: [test_output.npy]

-sample_inputs: [sample_input.npy]
-sample_outputs: [sample_output.npy]
+# sample_inputs: [sample_input.npy]
+# sample_outputs: [sample_output.npy]

weights:
pytorch_state_dict:
authors: [Ashwin Samudre;@bioimage-io]
-sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
-source: https://zenodo.org/record/3446812/files/unet3d_kinetochores_weights.torch
+#sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
+source: [best_checkpoint.pytorch]
Binary file added Bioimage-io-scripts/best_checkpoint.pytorch
Binary file not shown.
87 changes: 87 additions & 0 deletions Bioimage-io-scripts/check_model.py
@@ -0,0 +1,87 @@
# File based on https://github.com/mobie/platybrowser-datasets/blob/master/segmentation/cells/UNet3DPlatyCellProbs.model/check_model.py
# specific to the kinetochores use case

import os
import numpy as np
import torch

from pybio.spec.utils.transformers import load_and_resolve_spec
from pybio.spec.utils import get_instance


# TODO this is missing the normalization (preprocessing)
def check_model(path):
    """Check the model predictions against the stored expected test output."""
    spec = load_and_resolve_spec(path)

    with torch.no_grad():
        print("Loading inputs and outputs:")
        # load input and expected output data
        input_data = np.load(spec.test_inputs[0]).astype('float32')
        input_data = torch.from_numpy(input_data)
        expected_output_data = np.load(spec.test_outputs[0]).astype(np.float32)
        print(input_data.shape)

        # instantiate the model and load the state dict
        print("Predicting model")
        model = get_instance(spec)
        state = torch.load(spec.weights['pytorch_state_dict'].source)
        model.load_state_dict(state)

        # run prediction and compare against the expected output
        output_data = model(input_data).numpy()
        assert output_data.shape == expected_output_data.shape
        assert np.allclose(expected_output_data, output_data)
        print("Check passed")


# TODO this is missing the normalization (preprocessing)
def generate_output(path):
    """Regenerate the expected test output with the current weights."""
    spec = load_and_resolve_spec(path)

    with torch.no_grad():
        print("Loading inputs and outputs:")
        # load the input data
        input_data = np.load(spec.test_inputs[0]).astype('float32')
        input_data = torch.from_numpy(input_data)

        # instantiate the model and load the state dict
        print("Predicting model")
        model = get_instance(spec)
        state = torch.load(spec.weights['pytorch_state_dict'].source)
        model.load_state_dict(state)

        # run prediction and save it as the new expected output
        output_data = model(input_data).numpy()
        assert output_data.shape == input_data.shape
        np.save('./test_output.npy', output_data)


def resave_data():
    """Crop the older test data to the central region of twice the halo."""
    halo = [32, 48, 48]
    x = np.load('./test_input.npz')['arr_0']
    shape = x.shape[2:]
    bb = tuple(slice(sh // 2 - ha, sh // 2 + ha) for sh, ha in zip(shape, halo))
    bb = (slice(None), slice(None)) + bb
    x = x[bb]
    print(x.shape)
    np.save('./test_input.npy', x)

    y = np.load('./test_output.npz')['arr_0']
    y = y[bb]
    print(y.shape)
    np.save('./test_output.npy', y)


if __name__ == '__main__':
    # resave and crop the older test data
    # resave_data()

    path = os.path.abspath('./UNet3DKinetochores.model.yaml')

    # generate expected output again
    # generate_output(path)

    # check model predictions against the output
    check_model(path)
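The two TODOs above note that the spec's `zero_mean_unit_variance` preprocessing is never applied before inference. A minimal sketch of the missing step, assuming whole-array statistics (the spec's kwargs, collapsed in this diff, may restrict the axes instead):

```python
import numpy as np

def zero_mean_unit_variance(x, eps=1e-6):
    # Normalize to zero mean and unit variance, as the model yaml requests.
    # Whole-array statistics are an assumption here; per-channel or per-axis
    # modes would follow the kwargs given in the spec.
    return (x - x.mean()) / (x.std() + eps)

raw = np.random.rand(1, 48, 128, 128).astype(np.float32)
normed = zero_mean_unit_variance(raw)
# normed now has mean ~0 and std ~1, ready to feed to the network
```

In `check_model` above this would be applied to `input_data` before `torch.from_numpy`.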
10 changes: 5 additions & 5 deletions README.md
@@ -122,8 +122,8 @@ The actual useful IoU metric implemented. Taking threshold of predictions, getti…
3. Prepare the complete pipeline, starting with the input volume and ending with the final segmentation output.
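The thresholded IoU mentioned in the hunk header above could be sketched like this (the 0.5 threshold and the empty-union convention are assumptions):

```python
import numpy as np

def thresholded_iou(pred, target, thr=0.5):
    # IoU between a thresholded probability map and a binary target.
    p = pred > thr
    t = target > 0.5
    union = np.logical_or(p, t).sum()
    if union == 0:
        return 1.0  # both empty: treat as a perfect match (assumption)
    return np.logical_and(p, t).sum() / union

pred = np.array([[0.9, 0.2], [0.8, 0.1]])
target = np.array([[1, 0], [1, 1]])
print(thresholded_iou(pred, target))  # 2/3: intersection 2, union 3
```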

## Another thing to try (let's keep this aside for now):
-Harmonic embeddings are there for the 2d data and it works really good with the instance segmentation of biological images. The idea could be to adapt it to 3d data.
-Harmonic embeddings network are based on 2d datasets -> 1. quick trial with 2d slices for kinetochores data. (Not really useful)
-In fact, let's try -> 2.[Single cell](https://github.com/opnumten/single_cell_segmentation) and 3.[Spatial embeddings](https://github.com/davyneven/SpatialEmbeddings) both on 2d slices for our data.


+* Harmonic embeddings exist for 2d data and work really well for instance segmentation of biological images. The idea could be to adapt them to 3d data.
+* The harmonic embeddings network is based on 2d datasets -> 1. quick trial with 2d slices of the kinetochores data. (Not really useful)
+* In fact, let's try -> 2. [Single cell](https://github.com/opnumten/single_cell_segmentation) and 3. [Spatial embeddings](https://github.com/davyneven/SpatialEmbeddings), both on 2d slices of our data.
+* Analyze the distance between sources in a pair. If it's more or less constant, enforce this property in a loss.
+* We've used 1 + 3 (edt + vec_edt) channels for the regression task; another idea is to replace this with the distance and angles between sources in a pair: 1 channel for the distance between the sources and 3 channels for the angles. The convention could be a positive distance for the first element/source of the pair and a negative one for the second.
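The 1 + 3 (edt + vec_edt) channel layout described above can be sketched with `scipy.ndimage.distance_transform_edt`, which can also return the coordinates of the nearest background voxel (the exact channel convention used in training is an assumption):

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def edt_channels(mask):
    # Build a 1 + 3 regression target for a binary 3d mask: one distance
    # channel plus three per-axis offset-vector channels. The channel
    # convention is illustrative, not necessarily the one used in training.
    dist, nearest = distance_transform_edt(mask, return_indices=True)
    coords = np.indices(mask.shape)
    vec = coords - nearest  # offset from each voxel to its nearest background voxel
    return np.concatenate([dist[None], vec.astype(np.float64)], axis=0)

mask = np.zeros((8, 8, 8))
mask[2:6, 2:6, 2:6] = 1  # a small foreground cube
target = edt_channels(mask)
print(target.shape)  # (4, 8, 8, 8)
```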