diff --git a/Bioimage-io-scripts/UNet3DKinetochores.model.yaml b/Bioimage-io-scripts/UNet3DKinetochores.model.yaml index 62a4607..e41cf64 100644 --- a/Bioimage-io-scripts/UNet3DKinetochores.model.yaml +++ b/Bioimage-io-scripts/UNet3DKinetochores.model.yaml @@ -25,10 +25,10 @@ timestamp: 2019-12-11T12:22:32Z # ISO 8601 inputs: - name: raw description: raw input - axes: bczyx # letters of axes in btczyx + axes: czyx # letters of axes in btczyx data_type: float32 data_range: [-inf, inf] - shape: [1, 1, 48, 128, 128] + shape: [1, 48, 128, 128] preprocessing: # list of preprocessing steps - name: zero_mean_unit_variance # name of preprocessing step kwargs: @@ -45,14 +45,14 @@ inputs: outputs: - name: probability description: probability in [0,1] - axes: bczyx + axes: czyx data_type: float32 data_range: [-inf, inf] - halo: [0, 0, 32, 48, 48] + halo: [0, 32, 48, 48] shape: reference_input: raw - scale: [1, 1, 1, 1, 1] - offset: [0, 0, 0, 0, 0] + scale: [1, 1, 1, 1] + offset: [0, 0, 0, 0] language: python framework: pytorch @@ -63,11 +63,11 @@ dependencies: conda:../environment.yaml test_inputs: [test_input.npy] test_outputs: [test_output.npy] -sample_inputs: [sample_input.npy] -sample_outputs: [sample_output.npy] +# sample_inputs: [sample_input.npy] +# sample_outputs: [sample_output.npy] weights: pytorch_state_dict: authors: [Ashwin Samudre;@bioimage-io] - sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2 - source: https://zenodo.org/record/3446812/files/unet3d_kinetochores_weights.torch \ No newline at end of file + #sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2 + source: [best_checkpoint.pytorch] \ No newline at end of file diff --git a/Bioimage-io-scripts/best_checkpoint.pytorch b/Bioimage-io-scripts/best_checkpoint.pytorch new file mode 100644 index 0000000..1fb076d Binary files /dev/null and b/Bioimage-io-scripts/best_checkpoint.pytorch differ diff --git a/Bioimage-io-scripts/check_model.py 
b/Bioimage-io-scripts/check_model.py new file mode 100644 index 0000000..e35a385 --- /dev/null +++ b/Bioimage-io-scripts/check_model.py @@ -0,0 +1,87 @@ +# File based on https://github.com/mobie/platybrowser-datasets/blob/master/segmentation/cells/UNet3DPlatyCellProbs.model/check_model.py +# specific to Kinetochores use case + +import os +import numpy as np +import torch + +from pybio.spec.utils.transformers import load_and_resolve_spec +from pybio.spec.utils import get_instance + + +# TODO this is missing the normalization (preprocessing) +def check_model(path): + """ Convert model weights from format 'pytorch_state_dict' to 'torchscript'. + """ + spec = load_and_resolve_spec(path) + + with torch.no_grad(): + print("Loading inputs and outputs:") + # load input and expected output data + input_data = np.load(spec.test_inputs[0]).astype('float32') + input_data = torch.from_numpy(input_data) + expected_output_data = np.load(spec.test_outputs[0]).astype(np.float32) + print(input_data.shape) + + # instantiate and trace the model + print("Predicting model") + model = get_instance(spec) + state = torch.load(spec.weights['pytorch_state_dict'].source) + model.load_state_dict(state) + + # check the scripted model + output_data = model(input_data).numpy() + assert output_data.shape == expected_output_data.shape + assert np.allclose(expected_output_data, output_data) + print("Check passed") + + +# TODO this is missing the normalization (preprocessing) +def generate_output(path): + spec = load_and_resolve_spec(path) + + with torch.no_grad(): + print("Loading inputs and outputs:") + # load input and expected output data + input_data = np.load(spec.test_inputs[0]).astype('float32') + input_data = torch.from_numpy(input_data) + + # instantiate and trace the model + print("Predicting model") + model = get_instance(spec) + state = torch.load(spec.weights['pytorch_state_dict'].source) + model.load_state_dict(state) + + # check the scripted model + output_data = 
model(input_data).numpy() + assert output_data.shape == input_data.shape + np.save('./test_output.npy', output_data) + + +def resave_data(): + halo = [32, 48, 48] + x = np.load('./test_input.npz')['arr_0'] + shape = x.shape[2:] + bb = tuple(slice(sh // 2 - ha, sh // 2 + ha) for sh, ha in zip(shape, halo)) + bb = (slice(None), slice(None)) + bb + x = x[bb] + print(x.shape) + np.save('./test_input.npy', x) + + y = np.load('./test_output.npz')['arr_0'] + y = y[bb] + print(y.shape) + np.save('./test_output.npy', y) + + +if __name__ == '__main__': + # resave and crop the older test data + # resave_data() + + path = os.path.abspath('./UNet3DKinetochores.model.yaml') + + # generate expected output again + # generate_output(path) + + # check model predictions against the output + check_model(path) \ No newline at end of file diff --git a/README.md b/README.md index bf857f0..0c5894d 100644 --- a/README.md +++ b/README.md @@ -122,8 +122,8 @@ The actual useful IoU metric implemented. Taking threshold of predictions, getti 3. Prepare the complete pipeline starting with the input vol and final segmentation output. ## Another thing to try (let's keep this aside for now): -Harmonic embeddings are there for the 2d data and it works really good with the instance segmentation of biological images. The idea could be to adapt it to 3d data. -Harmonic embeddings network are based on 2d datasets -> 1. quick trial with 2d slices for kinetochores data. (Not really useful) -In fact, let's try -> 2.[Single cell](https://github.com/opnumten/single_cell_segmentation) and 3.[Spatial embeddings](https://github.com/davyneven/SpatialEmbeddings) both on 2d slices for our data. - - +* Harmonic embeddings are there for the 2d data and it works really good with the instance segmentation of biological images. The idea could be to adapt it to 3d data. +* Harmonic embeddings network are based on 2d datasets -> 1. quick trial with 2d slices for kinetochores data. 
(Not really useful) +* In fact, let's try -> 2.[Single cell](https://github.com/opnumten/single_cell_segmentation) and 3.[Spatial embeddings](https://github.com/davyneven/SpatialEmbeddings) both on 2d slices for our data. +* Analyze the distance between sources in a pair. If it's more or less constant, enforce this property in a loss. +* We've used 1 + 3 (edt + vec_edt) channels for the regression task; there's another idea to modify this with distance and angle between sources in a pair: 1 channel for distance between the sources and 3 channels for the angles. The convention could be to use positive distance for the first element/source of the pair and negative for the second element/source. \ No newline at end of file