{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n", "\n", "# <!-- TITLE --> [VAE1] - First VAE, with a small dataset (MNIST)\n", "<!-- DESC --> Construction and training of a VAE with a latent space of small dimension.\n", "\n", "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", "## Objectives :\n", " - Understanding and implementing a **variational autoencoder** neural network (VAE)\n", " - Understanding a more **advanced programming model**\n", "\n", "Because the computation is demanding, it is preferable to start with a very simple dataset such as MNIST.\n", "\n", "## What we're going to do :\n", "\n", " - Define a VAE model\n", " - Build the model\n", " - Train it\n", " - Follow the learning process with TensorBoard\n", "\n", "## Acknowledgements :\n", "Thanks to **François Chollet**, whose work is the basis of this example. \n", "François Chollet is not only the author of Keras and a great guru, he is also a sorcerer ;-) \n", "See : https://keras.io/examples/generative/vae\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Init Python stuff" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from skimage import io\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n", "\n", "import os,sys,json,time,datetime\n", "from IPython.display import display,Image,Markdown,HTML\n", "\n", "from modules.VAE import VAE, Sampling\n", "from modules.loader_MNIST import Loader_MNIST\n", "from modules.callbacks import ImagesCallback, BestModelCallback\n", "\n", "sys.path.append('..')\n", "import fidle.pwk as pwk\n", "\n", "run_dir = './run/MNIST.001' # Output directory\n", "datasets_dir = pwk.init('VAE1', run_dir)\n", "\n", "VAE.about()" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Parameters\n", "Uncomment the right lines according to what you want." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ---- Smart tests\n", "#\n", "latent_dim = 2\n", "r_loss_factor = 0.994\n", "scale = 0.1\n", "batch_size = 64\n", "epochs = 10\n", "\n", "# ---- Full run (1'30 on a V100)\n", "#\n", "# latent_dim = 2\n", "# r_loss_factor = 0.994\n", "# scale = 1.\n", "# batch_size = 64\n", "# epochs = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override parameters (batch mode) - Just forget this cell" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pwk.override('scale', 'latent_dim', 'r_loss_factor', 'batch_size', 'epochs')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Prepare data\n", "### 3.1 - Get it" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_train, _ = Loader_MNIST.get()\n", "np.random.shuffle(x_train)\n", "nb_images = int(len(x_train)*scale)\n", "x_train = x_train[:nb_images]\n", "print('\\nTrain shape after rescale : ',x_train.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2 - Have a look" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pwk.plot_images(x_train[:5], None, indices='all', columns=5, x_size=3,y_size=2, save_as='01-original')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Build model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = keras.Input(shape=(28, 28, 1))\n", "x = layers.Conv2D(32, 3, strides=1, padding=\"same\", activation=\"relu\")(inputs)\n", "x = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "x = 
layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2D(64, 3, strides=1, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Flatten()(x)\n", "x = layers.Dense(16, activation=\"relu\")(x)\n", "\n", "z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n", "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n", "z = Sampling()([z_mean, z_log_var])\n", "\n", "encoder = keras.Model(inputs, [z_mean, z_log_var, z], name=\"encoder\")\n", "encoder.compile()\n", "# encoder.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Decoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = keras.Input(shape=(latent_dim,))\n", "x = layers.Dense(7 * 7 * 64, activation=\"relu\")(inputs)\n", "x = layers.Reshape((7, 7, 64))(x)\n", "x = layers.Conv2DTranspose(64, 3, strides=1, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2DTranspose(32, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "outputs = layers.Conv2DTranspose(1, 3, padding=\"same\", activation=\"sigmoid\")(x)\n", "\n", "decoder = keras.Model(inputs, outputs, name=\"decoder\")\n", "decoder.compile()\n", "# decoder.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### VAE\n", "Our loss function is a weighted sum of two terms. \n", "`reconstruction_loss` measures how much the decoded image differs from the input image. \n", "`kl_loss` measures the dispersion of the encoded points in the latent space (Kullback-Leibler term). 
\n", "\n", "The weights are set by `r_loss_factor`: \n", "`total_loss = r_loss_factor*reconstruction_loss + (1-r_loss_factor)*kl_loss`\n", "\n", "If `r_loss_factor = 1`, the loss function includes only `reconstruction_loss`. \n", "If `r_loss_factor = 0`, the loss function includes only `kl_loss`. \n", "This notebook uses `r_loss_factor = 0.994` (set in Step 2), so the loss is weighted almost entirely toward reconstruction.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VAE(encoder, decoder, r_loss_factor)\n", "\n", "vae.compile(optimizer=keras.optimizers.Adam())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Train\n", "With `scale=1`, training takes about 1'15 on a GPU (V100 at IDRIS), or about 20' on a CPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ---- Callback : Images encoded\n", "pwk.mkdir(run_dir + '/images-encoded')\n", "filename = run_dir + '/images-encoded/image-{epoch:03d}-{i:02d}.jpg'\n", "callback_images1 = ImagesCallback(filename, x=x_train[:5], encoder=encoder,decoder=decoder)\n", "\n", "# ---- Callback : Images generated\n", "pwk.mkdir(run_dir + '/images-generated')\n", "filename = run_dir + '/images-generated/image-{epoch:03d}-{i:02d}.jpg'\n", "callback_images2 = ImagesCallback(filename, x=None, nb_images=5, z_dim=latent_dim, encoder=encoder,decoder=decoder) \n", "\n", "# ---- Callback : Best model\n", "pwk.mkdir(run_dir + '/models')\n", "filename = run_dir + '/models/best_model'\n", "callback_bestmodel = BestModelCallback(filename)\n", "\n", "# ---- Callback tensorboard\n", "dirname = run_dir + '/logs'\n", "callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)\n", "\n", "# ---- TensorBoard is left out by default; uncomment the first list to enable it\n", "# callbacks_list = [callback_images1, callback_images2, callback_bestmodel, callback_tensorboard]\n", "callbacks_list = [callback_images1, callback_images2, callback_bestmodel]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pwk.chrono_start()\n", "\n", "history = 
vae.fit(x_train, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list)\n", "\n", "pwk.chrono_show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 6 - About our training session\n", "### 6.1 - History" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pwk.plot_history(history, plot={\"Loss\":['loss','r_loss', 'kl_loss']}, save_as='history')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.2 - Reconstruction (input -> encoder -> decoder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ---- Reload the images saved by the callback, every second epoch\n", "imgs=[]\n", "for epoch in range(0,epochs,2):\n", "    for i in range(5):\n", "        filename = f'{run_dir}/images-encoded/image-{epoch:03d}-{i:02d}.jpg'\n", "        img = io.imread(filename)\n", "        imgs.append(img)\n", "\n", "pwk.subtitle('Original images :')\n", "pwk.plot_images(x_train[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)\n", "\n", "pwk.subtitle('Encoded/decoded images')\n", "pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-reconstruct')\n", "\n", "pwk.subtitle('Original images :')\n", "pwk.plot_images(x_train[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.3 - Generation (latent -> decoder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ---- Reload the images saved by the callback, every second epoch\n", "imgs=[]\n", "for epoch in range(0,epochs,2):\n", "    for i in range(5):\n", "        filename = f'{run_dir}/images-generated/image-{epoch:03d}-{i:02d}.jpg'\n", "        img = io.imread(filename)\n", "        imgs.append(img)\n", "\n", "pwk.subtitle('Generated images from latent space')\n", "pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-encoded')\n" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "pwk.end()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }