{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n", "\n", "# <!-- TITLE --> [K3VAE2] - VAE, using a custom model class (MNIST dataset)\n", "<!-- DESC --> Construction and training of a VAE, using model subclass, with a latent space of small dimension.\n", "\n", "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", "## Objectives :\n", " - Understanding and implementing a **variational autoencoder** neural network (VAE)\n", " - Understanding an even more **advanced programming model**, using a **custom model**\n", "\n", "Since the computational cost is significant, it is preferable to start with a very simple dataset such as MNIST. \n", "...and with a small `scale` if you don't have a GPU ;-)\n", "\n", "## What we're going to do :\n", "\n", " - Define a VAE model\n", " - Build the model\n", " - Train it\n", " - Have a look at the training process\n", "\n", "## Acknowledgements :\n", "Thanks to **François Chollet**, whose work is the basis of this example (and who is the creator of Keras !!). 
\n", "See : https://keras.io/examples/generative/vae\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Init python stuff" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['KERAS_BACKEND'] = 'torch'\n", "\n", "import keras\n", "from keras import layers\n", "\n", "import numpy as np\n", "\n", "from modules.models import VAE\n", "from modules.layers import SamplingLayer\n", "from modules.callbacks import ImagesCallback\n", "from modules.datagen import MNIST\n", "\n", "\n", "import matplotlib.pyplot as plt\n", "import scipy.stats\n", "import sys\n", "\n", "import fidle\n", "\n", "# Init Fidle environment\n", "run_id, run_dir, datasets_dir = fidle.init('K3VAE2')\n", "\n", "VAE.about()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Parameters\n", "`scale` : with scale=1, training takes about 1'30 on a V100 GPU ...and more than 20' on a CPU ! \n", "`latent_dim` : 2 dimensions is small, but useful for plotting (Step 7.4 only works with `latent_dim=2`) ! \n", "`fit_verbosity`: Verbosity of the training progress bar: 0=silent, 1=progress bar, 2=one line per epoch \n", "\n", "`loss_weights` : Our **loss function** is the weighted sum of two losses:\n", " - `r_loss` which measures the loss during reconstruction. \n", " - `kl_loss` which measures the dispersion. 
\n", "\n", "The weights are defined by: `loss_weights=[k1,k2]` where : `total_loss = k1*r_loss + k2*kl_loss` \n", "In practice, a value of \\[1,.06\\] gives good results here.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latent_dim = 6\n", "loss_weights = [1,.06]\n", "\n", "scale = .2\n", "seed = 123\n", "\n", "batch_size = 64\n", "epochs = 4\n", "fit_verbosity = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override parameters (batch mode) - Just forget this cell" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Prepare data\n", "`MNIST.get_data()` returns : `x_train,y_train, x_test,y_test`, \\\n", "but we only need x_train for our training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )\n", "\n", "fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Build model\n", "In this example, we will use a **custom model**.\n", "For this, we will use :\n", " - `SamplingLayer`, which generates a vector z from the parameters z_mean and z_log_var - See : [SamplingLayer.py](./modules/layers/SamplingLayer.py)\n", " - `VAE`, a custom model with a specific train_step - See : [VAE.py](./modules/models/VAE.py)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = keras.Input(shape=(28, 28, 1))\n", "x = layers.Conv2D(32, 3, strides=1, padding=\"same\", activation=\"relu\")(inputs)\n", "x = 
layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2D(64, 3, strides=1, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Flatten()(x)\n", "x = layers.Dense(16, activation=\"relu\")(x)\n", "\n", "z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n", "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n", "z = SamplingLayer()([z_mean, z_log_var])\n", "\n", "encoder = keras.Model(inputs, [z_mean, z_log_var, z], name=\"encoder\")\n", "encoder.compile()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Decoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = keras.Input(shape=(latent_dim,))\n", "x = layers.Dense(7 * 7 * 64, activation=\"relu\")(inputs)\n", "x = layers.Reshape((7, 7, 64))(x)\n", "x = layers.Conv2DTranspose(64, 3, strides=1, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "x = layers.Conv2DTranspose(32, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n", "outputs = layers.Conv2DTranspose(1, 3, padding=\"same\", activation=\"sigmoid\")(x)\n", "\n", "decoder = keras.Model(inputs, outputs, name=\"decoder\")\n", "decoder.compile()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### VAE\n", "`VAE` is a custom model with a specific train_step - See : [VAE.py](./modules/models/VAE.py)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VAE(encoder, decoder, loss_weights)\n", "\n", "vae.compile(optimizer='adam')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Train\n", "### 5.1 - Using two nice custom callbacks :-)\n", "Two custom callbacks are used:\n", " - `ImagesCallback` : which saves images during training - See 
[ImagesCallback.py](./modules/callbacks/ImagesCallback.py)\n", " - `BestModelCallback` : which saves the best model - See [BestModelCallback.py](./modules/callbacks/BestModelCallback.py)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "callback_images = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)\n", "\n", "callbacks_list = [callback_images]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.2 - Let's train !\n", "With `scale=1`, training takes about 1'15 on a GPU (V100 at IDRIS) ...or 20' on a CPU " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "chrono=fidle.Chrono()\n", "chrono.start()\n", "\n", "history = vae.fit(x_data, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)\n", "\n", "chrono.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 6 - Training review\n", "### 6.1 - History" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.scrawler.history(history, plot={\"Loss\":['loss']}, save_as='history')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.2 - Reconstruction during training\n", "At the end of each epoch, our callback saved some reconstructed images. 
\n", "Where : \n", "Original image -> encoder -> z -> decoder -> Reconstructed image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "images_z, images_r = callback_images.get_images( range(0,epochs,2) )\n", "\n", "fidle.utils.subtitle('Original images :')\n", "fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')\n", "\n", "fidle.utils.subtitle('Encoded/decoded images')\n", "fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')\n", "\n", "fidle.utils.subtitle('Original images :')\n", "fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.3 - Generation (latent -> decoder) during training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.utils.subtitle('Generated images from latent space')\n", "fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.4 - Save model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.makedirs(f'{run_dir}/models', exist_ok=True)\n", "\n", "vae.save(f'{run_dir}/models/vae_model.keras')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 7 - Model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7.1 - Reload model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae=VAE()\n", "vae.reload(f'{run_dir}/models/vae_model.keras')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7.2 - Image reconstruction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ---- Select few images\n", "\n", "x_show = 
fidle.utils.pick_dataset(x_data, n=10)\n", "\n", "# ---- Get latent points and reconstructed images\n", "\n", "z_mean, z_log_var, z = vae.encoder.predict(x_show)\n", "x_reconst = vae.decoder.predict(z)\n", "\n", "# ---- Show it\n", "\n", "labels=[ str(np.round(z[i],1)) for i in range(10) ]\n", "fidle.scrawler.images(x_show, None, indices='all', columns=10, x_size=2,y_size=2, save_as='05-original')\n", "fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='06-reconstruct')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7.3 - Visualization of the latent space" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_show = int(20000*scale)\n", "\n", "# ---- Select images\n", "\n", "x_show, y_show = fidle.utils.pick_dataset(x_data,y_data, n=n_show)\n", "\n", "# ---- Get latent points\n", "\n", "z_mean, z_log_var, z = vae.encoder.predict(x_show)\n", "\n", "# ---- Show them\n", "\n", "fig = plt.figure(figsize=(14, 10))\n", "plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)\n", "plt.colorbar()\n", "fidle.scrawler.save_fig('07-Latent-space')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7.4 - Generative latent space" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if latent_dim>2:\n", "\n", "    print('Sorry, this part only works if the latent space has 2 dimensions')\n", "\n", "else:\n", "\n", "    grid_size = 18\n", "    grid_scale = 1\n", "\n", "    # ---- Draw a ppf grid\n", "\n", "    grid=[]\n", "    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):\n", "        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):\n", "            grid.append( (x,y) )\n", "    grid=np.array(grid)\n", "\n", "    # ---- Draw latent points and grid\n", "\n", "    fig = plt.figure(figsize=(10, 8))\n", "    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', 
alpha=0.5, s=20)\n", "    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)\n", "    fidle.scrawler.save_fig('08-Latent-grid')\n", "    plt.show()\n", "\n", "    # ---- Plot the images corresponding to the grid\n", "\n", "    x_reconst = vae.decoder.predict([grid])\n", "    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='09-Latent-morphing')\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.end()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.2 ('fidle-env')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "vscode": { "interpreter": { "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d" } } }, "nbformat": 4, "nbformat_minor": 4 }