diff --git a/lsd/tutorial/example_nets/fib25/acrlsd/mknet.py b/lsd/tutorial/example_nets/fib25/acrlsd/mknet.py index 2215fcd..191f087 100644 --- a/lsd/tutorial/example_nets/fib25/acrlsd/mknet.py +++ b/lsd/tutorial/example_nets/fib25/acrlsd/mknet.py @@ -1,7 +1,7 @@ import json -import mala +from funlib.learn.tensorflow import models +from funlib.learn.tensorflow.models.unet import crop import tensorflow as tf -from mala.networks.unet import crop_zyx def create_auto(input_shape, output_shape, name): @@ -12,20 +12,20 @@ def create_auto(input_shape, output_shape, name): raw = tf.placeholder(tf.float32, shape=input_shape) raw_batched = tf.reshape(raw, (1, 1) + input_shape) - unet, _, _ = mala.networks.unet( + unet, _, _ = models.unet( raw_batched, 12, 6, [[2,2,2],[2,2,2],[3,3,3]]) - embedding_batched, _ = mala.networks.conv_pass( + embedding_batched, _ = models.conv_pass( unet, kernel_sizes=[1], num_fmaps=10, activation='sigmoid', name='embedding') - - embedding_batched = crop_zyx(embedding_batched, (1, 10) + output_shape) + + embedding_batched = crop(embedding_batched, (1, 10) + output_shape) embedding = tf.reshape(embedding_batched, (10,) + output_shape) print("input shape : %s"%(input_shape,)) @@ -50,7 +50,7 @@ def create_affs(input_shape, intermediate_shape, expected_output_shape, name): raw = tf.placeholder(tf.float32, shape=input_shape) raw_batched = tf.reshape(raw, (1, 1) + input_shape) raw_in = tf.reshape(raw_batched, input_shape) - raw_batched = crop_zyx(raw_batched, (1, 1) + intermediate_shape) + raw_batched = crop(raw_batched, (1, 1) + intermediate_shape) raw_cropped = tf.reshape(raw_batched, intermediate_shape) pretrained_lsd = tf.placeholder(tf.float32, shape=(10,) + intermediate_shape) @@ -58,13 +58,13 @@ def create_affs(input_shape, intermediate_shape, expected_output_shape, name): concat_input = tf.concat([raw_batched, pretrained_lsd_batched], axis=1) - unet, _, _ = mala.networks.unet( + unet, _, _ = models.unet( concat_input, 12, 6, [[2,2,2],[2,2,2],[3,3,3]]) - affs_batched, _ = mala.networks.conv_pass( + affs_batched, _ = models.conv_pass( unet, kernel_sizes=[1], num_fmaps=3, @@ -129,15 +129,15 @@ def create_config(input_shape, output_shape, name): if __name__ == "__main__": - train_input_shape = (304, 304, 304) + train_input_shape = (328, 328, 328) train_intermediate_shape = (196, 196, 196) - train_output_shape = (92, 92, 92) + train_output_shape = (72, 72, 72) create_auto(train_input_shape, train_intermediate_shape, 'train_auto_net') create_affs(train_input_shape, train_intermediate_shape, train_output_shape, 'train_net') test_input_shape = (364, 364, 364) - test_output_shape = (260, 260, 260) + test_output_shape = (240, 240, 240) create_affs(test_input_shape, test_input_shape, test_output_shape, 'test_net')