From b118c555d431773fbe990e53eb98dd1c2126c82e Mon Sep 17 00:00:00 2001 From: Austin Roose Date: Mon, 13 Feb 2023 01:09:44 +0100 Subject: [PATCH] tests for encoder and decoder return value dimensions --- model/__init__.py | 0 model/test_encoder.py | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 model/__init__.py create mode 100644 model/test_encoder.py diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/test_encoder.py b/model/test_encoder.py new file mode 100644 index 0000000..83e22c1 --- /dev/null +++ b/model/test_encoder.py @@ -0,0 +1,22 @@ +import torch +from UNet import Encoder, Decoder + + +def test_encoder_valid_shape(): + encoder = Encoder(num_input_channels=1, base_channel_size=32, latent_dim=256) + x = torch.randn(10000, 1, 28, 28) + res_shape = encoder(x).shape + assert res_shape[0] == 10000 and res_shape[1] == 256, "Correct encoding dimensions" + + +def test_decoder_valid_shape(): + decoder = Decoder(num_input_channels=1, base_channel_size=32, latent_dim=256) + x = torch.randn(1000, 256) + res_shape = decoder(x).shape + assert res_shape[0] == 1000 and res_shape[1] == 1, "Correct decoding dimensions" + + +if __name__ == "__main__": + test_encoder_valid_shape() + test_decoder_valid_shape() + print("Everything passed") \ No newline at end of file