From f89d992e3670e5514c56d9ac88eed35cdc5c8d96 Mon Sep 17 00:00:00 2001
From: Martijn van Beers
Date: Thu, 28 Feb 2019 12:41:52 +0100
Subject: [PATCH 1/5] Fix typo

---
 code/vae_notebook.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/code/vae_notebook.ipynb b/code/vae_notebook.ipynb
index db42ce1..1c822c6 100644
--- a/code/vae_notebook.ipynb
+++ b/code/vae_notebook.ipynb
@@ -140,8 +140,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "random_indeces = np.random.randint(valid_set.shape[0], size=5)\n",
-    "random_pictures = valid_set[random_indeces, :]\n",
+    "random_indices = np.random.randint(valid_set.shape[0], size=5)\n",
+    "random_pictures = valid_set[random_indices, :]\n",
     "width = height = int(sqrt(random_pictures.shape[1]))\n",
     "plot, axes = plt.subplots(1,5, sharex='col', sharey='row', figsize=(15,3))\n",
     "for i in range(5):\n",

From 4c5a58b6ae62009a18b9b3f1df682417c6dff277 Mon Sep 17 00:00:00 2001
From: Martijn van Beers
Date: Thu, 28 Feb 2019 12:42:31 +0100
Subject: [PATCH 2/5] Make formulas look a bit nicer

---
 code/vae_notebook.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/code/vae_notebook.ipynb b/code/vae_notebook.ipynb
index 1c822c6..98a99cd 100644
--- a/code/vae_notebook.ipynb
+++ b/code/vae_notebook.ipynb
@@ -166,9 +166,9 @@
     "\n",
     "In this tutorial, we will model the MNIST binarised digit data set. Each image is encoded as a 784-dimensional vector. We will model each of these vectors as a product of 784 Bernoullis (of course, there are better models but we want to keep it simple). Our likelihood is thus a product of independent Bernoullis. The resulting model is formally specified as \n",
     "\n",
-    "\\begin{align}z \\sim \\mathcal{N}(0,I) && x_i|z \\sim Bernoulli(NN_{\\theta}(z))~~~ i \\in \\{1,2,\\ldots, 784\\} \\ .\\end{align}\n",
+    "\\begin{align}z \\sim \\mathcal{N}(0,I) && x_i|z \\sim Bernoulli(\\text{NN}_{\\theta}(z))~~~ i \\in \\{1,2,\\ldots, 784\\} \\ .\\end{align}\n",
     "\n",
-    "The variational approximation is given by $$q(z|x) = \\mathcal{N}(NN_{\\lambda}(x), NN_{\\lambda}(x)).$$\n",
+    "The variational approximation is given by $$q(z|x) = \\mathcal{N}(\\text{NN}_{\\lambda}(x), \\text{NN}_{\\lambda}(x)).$$\n",
     "\n",
     "Notice that both the Bernoulli likelihood and the Gaussian variational distribution use NNs to compute their parameters. The parameters of the NNs, however, are different ($\\theta$ and $\\lambda$, respectively)."
    ]

From ff1b794b3e859d53f28190330d31fb42020e66c8 Mon Sep 17 00:00:00 2001
From: Martijn van Beers
Date: Thu, 28 Feb 2019 12:44:33 +0100
Subject: [PATCH 3/5] Improve logging and image plotting

* use logging.getLogger to get a logger to write to instead of calling
  logging directly. Also set the matplotlib logger to only print ERROR
  level messages, so we don't get extra output when plotting.
* create a generic function to plot one or more images using a
  specified number of columns.
  This allows the user to easily plot more samples.

---
 code/vae_notebook.ipynb | 80 ++++++++++++++++++++++-------------------
 1 file changed, 44 insertions(+), 36 deletions(-)

diff --git a/code/vae_notebook.ipynb b/code/vae_notebook.ipynb
index 98a99cd..50e1083 100644
--- a/code/vae_notebook.ipynb
+++ b/code/vae_notebook.ipynb
@@ -40,7 +40,7 @@
     "from abc import ABC\n",
     "from typing import List, Tuple, Callable, Optional, Iterable\n",
     "from matplotlib import cm, pyplot as plt\n",
-    "from math import sqrt"
+    "import math"
    ]
   },
   {
@@ -82,7 +82,35 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "logging.basicConfig(level=logging.DEBUG, format=\"%(asctime)s [%(levelname)s]: %(message)s\", datefmt=\"%H:%M:%S\")"
+    "def plot_images(a, columns):\n",
+    "    width = height = int(math.sqrt(a.shape[1]))\n",
+    "    n = a.shape[0]\n",
+    "    rows = math.ceil(n / columns)\n",
+    "    cellsize = min(15 / columns, 5)\n",
+    "    if rows == 1 and columns > n:\n",
+    "        if n == 1:\n",
+    "            plot, axes = plt.subplots(rows, sharex=\"col\", sharey=\"row\", figsize=(cellsize * columns, cellsize * rows))\n",
+    "        else:\n",
+    "            plot, axes = plt.subplots(rows, n, sharex=\"col\", sharey=\"row\", figsize=(cellsize * columns, cellsize * rows))\n",
+    "    else:\n",
+    "        plot, axes = plt.subplots(rows, columns, sharex=\"col\", sharey=\"row\", figsize=(cellsize * columns, cellsize * rows))\n",
+    "    for r in range(rows):\n",
+    "        for c in range(columns):\n",
+    "            if (r * columns + c) >= n:\n",
+    "                break\n",
+    "            if rows == 1:\n",
+    "                if n == 1:\n",
+    "                    axes.imshow(np.reshape(a[(r * columns) + c,:], (width, height)), cmap=cm.Greys)\n",
+    "                else:\n",
+    "                    axes[c].imshow(np.reshape(a[(r * columns) + c,:], (width, height)), cmap=cm.Greys)\n",
+    "            else:\n",
+    "                axes[r][c].imshow(np.reshape(a[(r * columns) + c,:], (width, height)), cmap=cm.Greys)\n",
+    "    plt.show()\n",
+    "\n",
+    "logging.basicConfig(level=logging.DEBUG, format=\"%(asctime)s [%(levelname)s -- %(name)s]: %(message)s\", datefmt=\"%H:%M:%S\")\n",
+    "logger = logging.getLogger(\"notebook\")\n",
+    "mpl_logger = logging.getLogger(\"matplotlib\")\n",
+    "mpl_logger.setLevel(logging.ERROR)"
    ]
   },
   {
@@ -107,13 +135,13 @@
     "    file_name = \"binary_mnist.{}\".format(data_set)\n",
     "    goal = join(data_dir, file_name)\n",
     "    if exists(goal):\n",
-    "        logging.info(\"Data file {} exists\".format(file_name))\n",
+    "        logger.info(\"Data file {} exists\".format(file_name))\n",
     "    else:\n",
-    "        logging.info(\"Downloading {}\".format(file_name))\n",
+    "        logger.info(\"Downloading {}\".format(file_name))\n",
     "        link = \"http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat\".format(\n",
     "            data_set)\n",
     "        urllib.request.urlretrieve(link, goal)\n",
-    "    logging.info(\"Finished\")"
+    "    logger.info(\"Finished\")"
    ]
   },
   {
@@ -130,7 +158,7 @@
    "outputs": [],
    "source": [
     "file_name = join(data_dir, \"binary_mnist.{}\".format(VALID_SET))\n",
-    "logging.info(\"Reading {} into memory\".format(file_name))\n",
+    "logger.info(\"Reading {} into memory\".format(file_name))\n",
     "valid_set = np.genfromtxt(file_name)"
    ]
   },
@@ -142,11 +170,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
     "random_indices = np.random.randint(valid_set.shape[0], size=5)\n",
     "random_pictures = valid_set[random_indices, :]\n",
-    "width = height = int(sqrt(random_pictures.shape[1]))\n",
-    "plot, axes = plt.subplots(1,5, sharex='col', sharey='row', figsize=(15,3))\n",
-    "for i in range(5):\n",
-    "    axes[i].imshow(np.reshape(random_pictures[i,:],(width,height)), cmap=cm.Greys)\n",
-    "plt.show()"
+    "plot_images(random_pictures, 5)"
    ]
@@ -676,9 +700,9 @@
     "source": [
"mnist = {}\n", "file_name = join(data_dir, \"binary_mnist.{}\".format(TRAIN_SET))\n", - "logging.info(\"Reading {} into memory\".format(file_name))\n", + "logger.info(\"Reading {} into memory\".format(file_name))\n", "mnist[TRAIN_SET] = mx.nd.array(np.genfromtxt(file_name))\n", - "logging.info(\"{} contains {} data points\".format(file_name, mnist[TRAIN_SET].shape[0]))" + "logger.info(\"{} contains {} data points\".format(file_name, mnist[TRAIN_SET].shape[0]))" ] }, { @@ -714,7 +738,7 @@ "source": [ "vae_module = mx.module.Module(vae.train(data=mx.sym.Variable(\"data\"), label=mx.sym.Variable(\"label\")),\n", " data_names=[train_iter.provide_data[0][0]],\n", - " label_names=[train_iter.provide_label[0][0]], context=ctx, logger=logging)" + " label_names=[train_iter.provide_label[0][0]], context=ctx, logger=logger)" ] }, { @@ -815,15 +839,8 @@ "# transform output into numpy arrays\n", "digits = digits.asnumpy()\n", "latent_values = latent_values.asnumpy()\n", - "rows = int(num_samples / 5)\n", "\n", - "plot, axes = plt.subplots(rows, 5, sharex='col', sharey='row', figsize=(30,6))\n", - "sample=0\n", - "for row in range(rows):\n", - " for col in range(5):\n", - " axes[row][col].imshow(np.reshape(digits[sample,:],(width,height)), cmap=cm.Greys)\n", - " sample += 1\n", - "plt.show()" + "plot_images(digits, 5)" ] }, { @@ -840,7 +857,7 @@ "outputs": [], "source": [ "file_name = join(data_dir, \"binary_mnist.{}\".format(TEST_SET))\n", - "logging.info(\"Reading {} into memory\".format(file_name))\n", + "logger.info(\"Reading {} into memory\".format(file_name))\n", "test_set = np.genfromtxt(file_name)" ] }, @@ -857,11 +874,9 @@ "metadata": {}, "outputs": [], "source": [ - "random_idx = np.random.randint(test_set.shape[0])\n", + "random_idx = np.random.randint(test_set.shape[0], size=1)\n", "random_picture = test_set[random_idx, :]\n", - "plot, canvas = plt.subplots(1, figsize=(5,5))\n", - "canvas.imshow(np.reshape(random_picture,(width,height)), cmap=cm.Greys)\n", - "plt.show()" + "plot_images(random_picture, 1)" ] }, { @@ -883,7 +898,7 @@ "# We group the outputs of generate_reconstructions to be able to process them as a single symbol\n", "reconstructions = mx.sym.Group(vae.generate_reconstructions(mx.sym.Variable(\"random_digit\"), num_samples))\n", "# construct an executor by binding the parameters to the learned values\n", - "params[\"random_digit\"] = mx.nd.array(random_picture.reshape((1,height*width)))\n", + "params[\"random_digit\"] = mx.nd.array(random_picture)\n", "reconstruction_exec = reconstructions.bind(ctx=ctx, args=params)\n", "\n", "# run the computation\n", @@ -894,14 +909,7 @@ "latent_values = latent_values.asnumpy()\n", "\n", "# plot the reconstructed digits\n", - "rows = int(num_samples / 5)\n", - "plot, axes = plt.subplots(rows, 5, sharex='col', sharey='row', figsize=(30,6))\n", - "sample=0\n", - "for row in range(rows):\n", - " for col in range(5):\n", - " axes[row][col].imshow(np.reshape(digits[sample,:],(width,height)), cmap=cm.Greys)\n", - " sample += 1\n", - "plt.show()" + "plot_images(digits, 5)" ] }, { From a588ab9a7bebb835d48f05118d5f2e5d9a2c48f2 Mon Sep 17 00:00:00 2001 From: Martijn van Beers Date: Thu, 28 Feb 2019 14:53:26 +0100 Subject: [PATCH 4/5] Pass in classes, not strings for encoder/decoder Passing in the class instead of a string makes construct_vae generic enough that it doesn't need changes when you want to play with different implementations. 
---
 code/vae_notebook.ipynb | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/code/vae_notebook.ipynb b/code/vae_notebook.ipynb
index 50e1083..f0f594d 100644
--- a/code/vae_notebook.ipynb
+++ b/code/vae_notebook.ipynb
@@ -606,8 +606,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def construct_vae(latent_type: str,\n",
-    "                  likelihood: str,\n",
+    "def construct_vae(latent_type: InferenceNetwork,\n",
+    "                  likelihood: Generator,\n",
     "                  generator_layer_sizes: List[int],\n",
     "                  infer_layer_sizes: List[int],\n",
     "                  latent_variable_size: int,\n",
@@ -627,18 +627,18 @@
     "    :param infer_act_type: Activation function for inference network hidden layers.\n",
     "    :return: A variational autoencoder.\n",
     "    \"\"\"\n",
-    "    if likelihood == \"bernoulliProd\":\n",
-    "        generator = ProductOfBernoullisGenerator(data_dims=data_dims, layer_sizes=generator_layer_sizes,\n",
+    "    if issubclass(likelihood, Generator):\n",
+    "        generator = likelihood(data_dims=data_dims, layer_sizes=generator_layer_sizes,\n",
     "                                                 act_type=generator_act_type)\n",
     "    else:\n",
-    "        raise Exception(\"{} is an invalid likelihood type.\".format(likelihood))\n",
+    "        raise Exception(\"{} is an invalid likelihood type. It should be a subclass of Generator\".format(likelihood))\n",
     "\n",
-    "    if latent_type == \"gaussian\":\n",
-    "        inference_net = GaussianInferenceNetwork(latent_variable_size=latent_variable_size,\n",
+    "    if issubclass(latent_type, InferenceNetwork):\n",
+    "        inference_net = latent_type(latent_variable_size=latent_variable_size,\n",
     "                                                 layer_sizes=infer_layer_sizes,\n",
     "                                                 act_type=infer_act_type)\n",
     "    else:\n",
-    "        raise Exception(\"{} is an invalid latent variable type.\".format(latent_type))\n",
+    "        raise Exception(\"{} is an invalid latent variable type. It should be a subclass of InferenceNetwork\".format(latent_type))\n",
     "    \n",
     "    return GaussianVAE(generator=generator, inference_net=inference_net, kl_divergence=diagonal_gaussian_kl)"
    ]
@@ -658,7 +658,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vae = construct_vae(latent_type=\"gaussian\", likelihood=\"bernoulliProd\", generator_layer_sizes=[200,500],\n",
+    "vae = construct_vae(latent_type=GaussianInferenceNetwork, likelihood=ProductOfBernoullisGenerator, generator_layer_sizes=[200,500],\n",
     "                    infer_layer_sizes=[500,200], latent_variable_size=200, data_dims=784, generator_act_type=\"tanh\",\n",
     "                    infer_act_type=\"tanh\")"
    ]

From adf67802e81fe9805723b93f5a2d6426df2b71b4 Mon Sep 17 00:00:00 2001
From: Martijn van Beers
Date: Thu, 28 Feb 2019 15:10:13 +0100
Subject: [PATCH 5/5] Remove unused kl_divergence parameter

The kl_divergence parameter in GaussianVAE was never used, and it seems
diagonal_gaussian_kl should be private to GaussianInferenceNetwork, so
remove it.
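
For reference, the quantity diagonal_gaussian_kl computes is the
closed-form KL divergence between the variational posterior
q(z|x) = N(mu, diag(sigma^2)) and the prior N(0, I). A minimal numpy
sketch of that formula (the notebook's implementation works on mx.sym
symbols instead; this is only the maths):

    import numpy as np

    def diagonal_gaussian_kl(mu, sigma):
        # KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dims:
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        return -0.5 * np.sum(1.0 + np.log(sigma ** 2) - mu ** 2 - sigma ** 2, axis=-1)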
---
 code/vae_notebook.ipynb | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/code/vae_notebook.ipynb b/code/vae_notebook.ipynb
index f0f594d..03ad849 100644
--- a/code/vae_notebook.ipynb
+++ b/code/vae_notebook.ipynb
@@ -38,7 +38,7 @@
     "import os, logging, sys\n",
     "from os.path import join, exists\n",
     "from abc import ABC\n",
-    "from typing import List, Tuple, Callable, Optional, Iterable\n",
+    "from typing import List, Tuple, Optional, Iterable\n",
     "from matplotlib import cm, pyplot as plt\n",
     "import math"
    ]
@@ -549,11 +549,8 @@
     "\n",
     "    def __init__(self,\n",
     "                 generator: Generator,\n",
-    "                 inference_net: GaussianInferenceNetwork,\n",
-    "                 kl_divergence: Callable) -> None:\n",
-    "        self.generator = generator\n",
-    "        self.inference_net = inference_net\n",
-    "        self.kl_divergence = kl_divergence\n",
+    "                 inference_net: GaussianInferenceNetwork) -> None:\n",
+    "        super(GaussianVAE, self).__init__(generator, inference_net)\n",
     "\n",
     "    def train(self, data: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:\n",
     "        \"\"\"\n",
@@ -640,7 +637,7 @@
     "    else:\n",
     "        raise Exception(\"{} is an invalid latent variable type. It should be a subclass of InferenceNetwork\".format(latent_type))\n",
     "    \n",
-    "    return GaussianVAE(generator=generator, inference_net=inference_net, kl_divergence=diagonal_gaussian_kl)"
+    "    return GaussianVAE(generator=generator, inference_net=inference_net)"
    ]
   },
   {