117 changes: 61 additions & 56 deletions code/vae_notebook.ipynb
@@ -38,9 +38,9 @@
"import os, logging, sys\n",
"from os.path import join, exists\n",
"from abc import ABC\n",
"from typing import List, Tuple, Callable, Optional, Iterable\n",
"from typing import List, Tuple, Optional, Iterable\n",
"from matplotlib import cm, pyplot as plt\n",
"from math import sqrt"
"import math"
]
},
{
@@ -82,7 +82,35 @@
"metadata": {},
"outputs": [],
"source": [
"logging.basicConfig(level=logging.DEBUG, format=\"%(asctime)s [%(levelname)s]: %(message)s\", datefmt=\"%H:%M:%S\")"
"def plot_images(a, columns):\n",
" width = height = int(math.sqrt(a.shape[1]))\n",
" n = a.shape[0]\n",
" rows = math.ceil(n / columns)\n",
" cellsize = min(15 / columns, 5)\n",
" if rows == 1 and columns > n:\n",
" if n == 1:\n",
" plot, axes = plt.subplots(rows, sharex=\"col\", sharey=\"row\", figsize=(cellsize * columns, cellsize * rows))\n",
" else:\n",
" plot, axes = plt.subplots(rows, n, sharex=\"col\", sharey=\"row\", figsize=(cellsize * columns, cellsize * rows))\n",
" else:\n",
" plot, axes = plt.subplots(rows, columns, sharex=\"col\", sharey=\"row\", figsize=(cellsize * columns, cellsize * rows))\n",
" for r in range(rows):\n",
" for c in range(columns):\n",
" if (r * columns + c) >= n:\n",
" break\n",
" if rows == 1:\n",
" if n == 1:\n",
" axes.imshow(np.reshape(a[(r * columns) + c,:], (width, height)), cmap=cm.Greys)\n",
" else:\n",
" axes[c].imshow(np.reshape(a[(r * columns) + c,:], (width, height)), cmap=cm.Greys)\n",
" else:\n",
" axes[r][c].imshow(np.reshape(a[(r * columns) + c,:], (width, height)), cmap=cm.Greys)\n",
" plt.show()\n",
"\n",
"logging.basicConfig(level=logging.DEBUG, format=\"%(asctime)s [%(levelname)s -- %(name)s]: %(message)s\", datefmt=\"%H:%M:%S\")\n",
"logger = logging.getLogger(\"notebook\")\n",
"mpl_logger = logging.getLogger(\"matplotlib\")\n",
"mpl_logger.setLevel(logging.ERROR)"
]
},
{
@@ -107,13 +135,13 @@
" file_name = \"binary_mnist.{}\".format(data_set)\n",
" goal = join(data_dir, file_name)\n",
" if exists(goal):\n",
" logging.info(\"Data file {} exists\".format(file_name))\n",
" logger.info(\"Data file {} exists\".format(file_name))\n",
" else:\n",
" logging.info(\"Downloading {}\".format(file_name))\n",
" logger.info(\"Downloading {}\".format(file_name))\n",
" link = \"http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat\".format(\n",
" data_set)\n",
" urllib.request.urlretrieve(link, goal)\n",
" logging.info(\"Finished\")"
" logger.info(\"Finished\")"
]
},
{
@@ -130,7 +158,7 @@
"outputs": [],
"source": [
"file_name = join(data_dir, \"binary_mnist.{}\".format(VALID_SET))\n",
"logging.info(\"Reading {} into memory\".format(file_name))\n",
"logger.info(\"Reading {} into memory\".format(file_name))\n",
"valid_set = np.genfromtxt(file_name)"
]
},
@@ -140,13 +168,9 @@
"metadata": {},
"outputs": [],
"source": [
"random_indeces = np.random.randint(valid_set.shape[0], size=5)\n",
"random_pictures = valid_set[random_indeces, :]\n",
"width = height = int(sqrt(random_pictures.shape[1]))\n",
"plot, axes = plt.subplots(1,5, sharex='col', sharey='row', figsize=(15,3))\n",
"for i in range(5):\n",
" axes[i].imshow(np.reshape(random_pictures[i,:],(width,height)), cmap=cm.Greys)\n",
"plt.show()"
"random_indices = np.random.randint(valid_set.shape[0], size=5)\n",
"random_pictures = valid_set[random_indices, :]\n",
"plot_images(random_pictures, 5)"
]
},
{
@@ -166,9 +190,9 @@
"\n",
"In this tutorial, we will model the mist binarised digit data set. Each image is encoded as a 784-dimensional vector. We will model each of these vectors as a product of 784 Bernoullis (of course, there are better models but we want to keep it simple). Our likelihood is thus a product of independent Bernoullis. The resulting model is formally specified as \n",
"\n",
"\\begin{align}z \\sim \\mathcal{N}(0,I) && x_i|z \\sim Bernoulli(NN_{\\theta}(z))~~~ i \\in \\{1,2,\\ldots, 784\\} \\ .\\end{align}\n",
"\\begin{align}z \\sim \\mathcal{N}(0,I) && x_i|z \\sim Bernoulli(\\text{NN}_{\\theta}(z))~~~ i \\in \\{1,2,\\ldots, 784\\} \\ .\\end{align}\n",
"\n",
"The variational approximation is given by $$q(z|x) = \\mathcal{N}(NN_{\\lambda}(x), NN_{\\lambda}(x)).$$\n",
"The variational approximation is given by $$q(z|x) = \\mathcal{N}(\\text{NN}_{\\lambda}(x), \\text{NN}_{\\lambda}(x)).$$\n",
"\n",
"Notice that both the Bernoulli likelihood and the Gaussian variational distribution use NNs to compute their parameters. The parameters of the NNs, however, are different ($\\theta$ and $\\lambda$, respectively)."
]
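As a rough illustration of this generative story and its variational approximation, here is a minimal NumPy sketch. The single-layer weight matrices are hypothetical stand-ins for $NN_{\theta}$ and $NN_{\lambda}$; the notebook builds the actual networks with MXNet symbols.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, data_dim = 200, 784

# Hypothetical single-layer stand-ins for NN_theta (decoder) and NN_lambda (encoder).
W_dec = rng.normal(scale=0.01, size=(latent_dim, data_dim))
W_mu = rng.normal(scale=0.01, size=(data_dim, latent_dim))
W_logvar = rng.normal(scale=0.01, size=(data_dim, latent_dim))

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Generative story: z ~ N(0, I), then x_i | z ~ Bernoulli(NN_theta(z)_i) for each pixel.
z = rng.normal(size=latent_dim)
x = rng.binomial(1, sigmoid(z @ W_dec))

# Variational approximation q(z|x) = N(mu, diag(sigma^2)), with mu and log sigma^2
# computed from x; sampled via the reparameterisation trick z = mu + sigma * eps,
# which keeps the sample differentiable with respect to the variational parameters.
mu, log_var = x @ W_mu, x @ W_logvar
eps = rng.normal(size=latent_dim)
z_sample = mu + np.exp(0.5 * log_var) * eps
```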
@@ -525,11 +549,8 @@
"\n",
" def __init__(self,\n",
" generator: Generator,\n",
" inference_net: GaussianInferenceNetwork,\n",
" kl_divergence: Callable) -> None:\n",
Collaborator:

Why are you removing the KL divergence? It is needed to properly construct the ELBO.

I have thought about this a lot before, and one alternative is to rewrite the ELBO as E[log p(x,z)] + H(q(z)). Then we could provide the entropy as a method of the inference net. However, it's out of whack with the slides in that case. That's why, for the time being, I would just pass in the kl_divergence as an argument.

Author:

I removed it because it wasn't actually being used (as far as I can see): the GaussianInferenceNetwork calls diagonal_gaussian_kl() directly. I've been implementing inference networks with different distributions with Wilker, which require different ways to calculate the KL divergence, so it seemed to me that the KL divergence is conceptually linked more to the inference network than to the VAE as a whole. That's why I opted for removing the parameter rather than making GaussianInferenceNetwork stop hardcoding the function.
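For reference, a sketch of the identity behind the alternative mentioned above, using the notation from the notebook: the KL form and the entropy form of the ELBO coincide because $\mathrm{KL}(q\,\|\,p) = \mathbb{E}_q[\log q(z|x)] - \mathbb{E}_q[\log p(z)]$.

```latex
\begin{align}
\mathcal{L}(\theta,\lambda)
  &= \mathbb{E}_{q_\lambda(z|x)}\left[\log p_\theta(x|z)\right]
     - \mathrm{KL}\left(q_\lambda(z|x)\,\|\,p(z)\right) \\
  &= \mathbb{E}_{q_\lambda(z|x)}\left[\log p_\theta(x,z)\right]
     + \mathbb{H}\left[q_\lambda(z|x)\right]
\end{align}
```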

" self.generator = generator\n",
" self.inference_net = inference_net\n",
" self.kl_divergence = kl_divergence\n",
" inference_net: GaussianInferenceNetwork) -> None:\n",
" super(GaussianVAE, self).__init__(generator, inference_net)\n",
"\n",
" def train(self, data: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:\n",
" \"\"\"\n",
@@ -582,8 +603,8 @@
"metadata": {},
"outputs": [],
"source": [
"def construct_vae(latent_type: str,\n",
" likelihood: str,\n",
"def construct_vae(latent_type: InferenceNetwork,\n",
" likelihood: Generator,\n",
" generator_layer_sizes: List[int],\n",
" infer_layer_sizes: List[int],\n",
" latent_variable_size: int,\n",
@@ -603,20 +624,20 @@
" :param infer_act_type: Activation function for inference network hidden layers.\n",
" :return: A variational autoencoder.\n",
" \"\"\"\n",
" if likelihood == \"bernoulliProd\":\n",
" generator = ProductOfBernoullisGenerator(data_dims=data_dims, layer_sizes=generator_layer_sizes,\n",
" if issubclass(likelihood, Generator):\n",
" generator = likelihood(data_dims=data_dims, layer_sizes=generator_layer_sizes,\n",
" act_type=generator_act_type)\n",
" else:\n",
" raise Exception(\"{} is an invalid likelihood type.\".format(likelihood))\n",
" raise Exception(\"{} is an invalid likelihood type. It should be a subclass of Generator\".format(likelihood))\n",
"\n",
" if latent_type == \"gaussian\":\n",
" inference_net = GaussianInferenceNetwork(latent_variable_size=latent_variable_size,\n",
" if issubclass(latent_type, InferenceNetwork):\n",
" inference_net = latent_type(latent_variable_size=latent_variable_size,\n",
" layer_sizes=infer_layer_sizes,\n",
" act_type=infer_act_type)\n",
" else:\n",
" raise Exception(\"{} is an invalid latent variable type.\".format(latent_type))\n",
" raise Exception(\"{} is an invalid latent variable type. It should be a subclass of InferenceNetwork\".format(latent_type))\n",
" \n",
" return GaussianVAE(generator=generator, inference_net=inference_net, kl_divergence=diagonal_gaussian_kl)"
" return GaussianVAE(generator=generator, inference_net=inference_net)"
]
},
{
@@ -634,7 +655,7 @@
"metadata": {},
"outputs": [],
"source": [
"vae = construct_vae(latent_type=\"gaussian\", likelihood=\"bernoulliProd\", generator_layer_sizes=[200,500],\n",
"vae = construct_vae(latent_type=GaussianInferenceNetwork, likelihood=ProductOfBernoullisGenerator, generator_layer_sizes=[200,500],\n",
" infer_layer_sizes=[500,200], latent_variable_size=200, data_dims=784, generator_act_type=\"tanh\",\n",
" infer_act_type=\"tanh\")"
]
@@ -676,9 +697,9 @@
"source": [
"mnist = {}\n",
"file_name = join(data_dir, \"binary_mnist.{}\".format(TRAIN_SET))\n",
"logging.info(\"Reading {} into memory\".format(file_name))\n",
"logger.info(\"Reading {} into memory\".format(file_name))\n",
"mnist[TRAIN_SET] = mx.nd.array(np.genfromtxt(file_name))\n",
"logging.info(\"{} contains {} data points\".format(file_name, mnist[TRAIN_SET].shape[0]))"
"logger.info(\"{} contains {} data points\".format(file_name, mnist[TRAIN_SET].shape[0]))"
]
},
{
@@ -714,7 +735,7 @@
"source": [
"vae_module = mx.module.Module(vae.train(data=mx.sym.Variable(\"data\"), label=mx.sym.Variable(\"label\")),\n",
" data_names=[train_iter.provide_data[0][0]],\n",
" label_names=[train_iter.provide_label[0][0]], context=ctx, logger=logging)"
" label_names=[train_iter.provide_label[0][0]], context=ctx, logger=logger)"
]
},
{
@@ -815,15 +836,8 @@
"# transform output into numpy arrays\n",
"digits = digits.asnumpy()\n",
"latent_values = latent_values.asnumpy()\n",
"rows = int(num_samples / 5)\n",
"\n",
"plot, axes = plt.subplots(rows, 5, sharex='col', sharey='row', figsize=(30,6))\n",
"sample=0\n",
"for row in range(rows):\n",
" for col in range(5):\n",
" axes[row][col].imshow(np.reshape(digits[sample,:],(width,height)), cmap=cm.Greys)\n",
" sample += 1\n",
"plt.show()"
"plot_images(digits, 5)"
]
},
{
@@ -840,7 +854,7 @@
"outputs": [],
"source": [
"file_name = join(data_dir, \"binary_mnist.{}\".format(TEST_SET))\n",
"logging.info(\"Reading {} into memory\".format(file_name))\n",
"logger.info(\"Reading {} into memory\".format(file_name))\n",
"test_set = np.genfromtxt(file_name)"
]
},
@@ -857,11 +871,9 @@
"metadata": {},
"outputs": [],
"source": [
"random_idx = np.random.randint(test_set.shape[0])\n",
"random_idx = np.random.randint(test_set.shape[0], size=1)\n",
"random_picture = test_set[random_idx, :]\n",
"plot, canvas = plt.subplots(1, figsize=(5,5))\n",
"canvas.imshow(np.reshape(random_picture,(width,height)), cmap=cm.Greys)\n",
"plt.show()"
"plot_images(random_picture, 1)"
]
},
{
@@ -883,7 +895,7 @@
"# We group the outputs of generate_reconstructions to be able to process them as a single symbol\n",
"reconstructions = mx.sym.Group(vae.generate_reconstructions(mx.sym.Variable(\"random_digit\"), num_samples))\n",
"# construct an executor by binding the parameters to the learned values\n",
"params[\"random_digit\"] = mx.nd.array(random_picture.reshape((1,height*width)))\n",
"params[\"random_digit\"] = mx.nd.array(random_picture)\n",
"reconstruction_exec = reconstructions.bind(ctx=ctx, args=params)\n",
"\n",
"# run the computation\n",
Expand All @@ -894,14 +906,7 @@
"latent_values = latent_values.asnumpy()\n",
"\n",
"# plot the reconstructed digits\n",
"rows = int(num_samples / 5)\n",
"plot, axes = plt.subplots(rows, 5, sharex='col', sharey='row', figsize=(30,6))\n",
"sample=0\n",
"for row in range(rows):\n",
" for col in range(5):\n",
" axes[row][col].imshow(np.reshape(digits[sample,:],(width,height)), cmap=cm.Greys)\n",
" sample += 1\n",
"plt.show()"
"plot_images(digits, 5)"
]
},
{