{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Variational AutoEncoder (VAE) with CelebA\n", "=========================================\n", "---\n", "Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020 \n", "\n", "## Episode 1 - Train a model\n", "\n", " - Defining a VAE model\n", " - Build the model\n", " - Train it\n", " - Follow the learning process with Tensorboard\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Setup environment\n", "### 1.1 - Python stuff" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<style>\n", "\n", "div.warn { \n", " background-color: #fcf2f2;\n", " border-color: #dFb5b4;\n", " border-left: 5px solid #dfb5b4;\n", " padding: 0.5em;\n", " font-weight: bold;\n", " font-size: 1.1em;;\n", " }\n", "\n", "\n", "\n", "div.nota { \n", " background-color: #DAFFDE;\n", " border-left: 5px solid #92CC99;\n", " padding: 0.5em;\n", " }\n", "\n", "\n", "\n", "</style>\n", "\n" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "FIDLE 2020 - Practical Work Module\n", "Version : 0.2.8\n", "Run time : Friday 14 February 2020, 00:07:28\n", "TensorFlow version : 2.0.0\n", "Keras version : 2.2.4-tf\n", "\n", "FIDLE 2020 - Variational AutoEncoder (VAE)\n", "TensorFlow version : 2.0.0\n", "VAE version : 1.28\n", "\n", "FIDLE 2020 - DataGenerator\n", "Version : 0.4.1\n" ] } ], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import os,sys\n", "from importlib import reload\n", "\n", "import modules.vae\n", "import modules.data_generator\n", "\n", "reload(modules.data_generator)\n", "reload(modules.vae)\n", "\n", "from modules.vae import VariationalAutoencoder\n", "from modules.data_generator import DataGenerator\n", "\n", "sys.path.append('..')\n", "import 
fidle.pwk as ooo\n", "reload(ooo)\n", "\n", "ooo.init()\n", "\n", "VariationalAutoencoder.about()\n", "DataGenerator.about()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 - The good place" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Well, we should be at IDRIS !\n", "We are going to use: /gpfswork/rech/mlh/uja62cb/datasets/celeba\n" ] } ], "source": [
 "# ---- candidate dataset roots, one per known platform\n",
 "#      (each root starts from an environment variable; os.getenv falls back to '' when it is unset)\n",
 "\n",
 "dataset_locations = {\n",
 "    'GRICAD' : f'{os.getenv(\"SCRATCH_DIR\",\"\")}/PROJECTS/pr-fidle/datasets/celeba',\n",
 "    'IDRIS'  : f'{os.getenv(\"WORK\",\"\")}/datasets/celeba',\n",
 "    'HOME'   : f'{os.getenv(\"HOME\",\"\")}/datasets/celeba',\n",
 "}\n",
 "\n",
 "# good_place() hands back a (place, dataset_dir) pair - presumably the entry that matches\n",
 "# the current host; see fidle.pwk to confirm\n",
 "place, dataset_dir = ooo.good_place(dataset_locations)\n",
 "\n",
 "# ---- train/test datasets\n",
 "\n",
 "train_dir = f'{dataset_dir}/clusters-M.train'\n",
 "test_dir = f'{dataset_dir}/clusters-M.test'"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - DataGenerator and validation data\n", "Ok, everything's perfect, now let's instantiate our generator for the entire dataset."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data generator : 6250 batchs of 32 images, or 200000 images\n", "x_test : 2599 images\n" ] } ], "source": [
 "# ---- training batches come from the generator, validation images from a single .npy file\n",
 "\n",
 "batch_size = 32\n",
 "\n",
 "data_gen = DataGenerator(train_dir, batch_size, k_size=1)\n",
 "x_test = np.load(f'{test_dir}/images-000.npy')\n",
 "\n",
 "print(f'Data generator : {len(data_gen)} batchs of {data_gen.batch_size} images, or {data_gen.dataset_size} images')\n",
 "print(f'x_test : {len(x_test)} images')"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Get VAE model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model initialized.\n", "Outputs will be in : ./run/CelebA.052-M.810156\n", "Config saved in : ./run/CelebA.052-M.810156/models/vae_config.json\n" ] } ], "source": [
 "# ---- run tag : ends with the SLURM job id, or 'unknown' when SLURM_JOB_ID is not set\n",
 "\n",
 "tag = f'CelebA.052-M.{os.getenv(\"SLURM_JOB_ID\",\"unknown\")}'\n",
 "\n",
 "input_shape = (192, 160, 3)\n",
 "z_dim = 200\n",
 "verbose = 0\n",
 "\n",
 "# ---- encoder : four stride-2 convolutions with 'same' padding, each followed by dropout\n",
 "#      Assuming these descriptors map to the Keras layers of the same name, every convolution\n",
 "#      halves height and width : 192x160 -> 96x80 -> 48x40 -> 24x20 -> 12x10\n",
 "\n",
 "encoder = [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "            {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "            {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "            {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "          ]\n",
 "\n",
 "# ---- decoder : four stride-2 transposed convolutions mirroring the encoder\n",
 "#      The last one has 3 filters (one per channel of input_shape) and a sigmoid activation\n",
 "\n",
 "decoder = [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "            {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "            {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
 "            {'type':'Dropout', 'rate':0.25},\n",
 "            {'type':'Conv2DTranspose', 'filters':3, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'sigmoid'}\n",
 "          ]\n",
 "\n",
 "# ---- model : VariationalAutoencoder is the class imported from modules.vae in the setup cell\n",
 "\n",
 "vae = VariationalAutoencoder(input_shape    = input_shape,\n",
 "                             encoder_layers = encoder,\n",
 "                             decoder_layers = decoder,\n",
 "                             z_dim          = z_dim,\n",
 "                             verbose        = verbose,\n",
 "                             run_tag        = tag)\n",
 "\n",
 "# With model=None the recorded run only reports the config file being written\n",
 "vae.save(model=None)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Compile it" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Compiled.\n" ] } ], "source": [
 "# Adam with a learning rate of 1e-4\n",
 "optimizer = tf.keras.optimizers.Adam(1e-4)\n",
 "\n",
 "# Factor handed to vae.compile() - presumably the weight of the reconstruction loss\n",
 "# relative to the KL loss; see modules.vae to confirm\n",
 "r_loss_factor = 10000\n",
 "\n",
 "vae.compile(optimizer, r_loss_factor)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Train\n", "Reference run times with the Adam optimizer : \n", "- 10 epochs at IDRIS : 1299.77 sec. - 0:21:39\n", "- 10 epochs at GRICAD : 2092.77 sec. - 0:34:52\n", "- 20 epochs at IDRIS with medium resolution : Train duration : 6638.61 sec. 
- 1:50:38" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
 "# ---- training schedule, passed as-is to vae.train() in the next cell\n",
 "\n",
 "initial_epoch = 0\n",
 "epochs = 20"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "Epoch 1/20\n",
 "6250/6250 [==============================] - 342s 55ms/step - loss: 332.4792 - vae_r_loss: 282.9655 - vae_kl_loss: 49.5134 - val_loss: 235.4288 - val_vae_r_loss: 187.4334 - val_vae_kl_loss: 48.1192\n",
 "Epoch 2/20\n",
 "6250/6250 [==============================] - 337s 54ms/step - loss: 224.0962 - vae_r_loss: 166.4602 - vae_kl_loss: 57.6360 - val_loss: 210.9632 - val_vae_r_loss: 155.0925 - val_vae_kl_loss: 56.0114\n",
 "Epoch 4/20\n",
 "6250/6250 [==============================] - 333s 53ms/step - loss: 214.5463 - vae_r_loss: 155.7666 - vae_kl_loss: 58.7794 - val_loss: 203.8241 - val_vae_r_loss: 147.3248 - val_vae_kl_loss: 56.5778\n",
 "Epoch 7/20\n",
 "6250/6250 [==============================] - 327s 52ms/step - loss: 211.1459 - vae_r_loss: 152.4026 - vae_kl_loss: 58.7437 - val_loss: 201.1862 - val_vae_r_loss: 145.6906 - val_vae_kl_loss: 55.6326\n",
 "Epoch 10/20\n",
 "6250/6250 [==============================] - 335s 54ms/step - loss: 209.7628 - vae_r_loss: 151.0874 - vae_kl_loss: 58.6756 - val_loss: 202.3954 - val_vae_r_loss: 147.0956 - val_vae_kl_loss: 55.4047\n",
 "Epoch 12/20\n",
 "6250/6250 [==============================] - 333s 53ms/step - loss: 207.9830 - vae_r_loss: 149.3870 - vae_kl_loss: 58.5959 - val_loss: 198.5626 - val_vae_r_loss: 142.5848 - val_vae_kl_loss: 56.0871\n",
 "Epoch 16/20\n",
 "6250/6250 [==============================] - 330s 53ms/step - loss: 206.6382 - vae_r_loss: 148.0522 - vae_kl_loss: 58.5863 - val_loss: 197.5800 - val_vae_r_loss: 142.6799 - val_vae_kl_loss: 54.9832\n",
 "\n",
 "Train duration : 6638.61 sec. - 1:50:38\n"
 ] } ], "source": [
 "# ---- train : batches come from data_gen; x_test is passed as the held-out images\n",
 "#      (the recorded log reports val_* metrics at each epoch)\n",
 "\n",
 "vae.train(data_generator=data_gen, x_test=x_test, epochs=epochs, initial_epoch=initial_epoch)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----\n", "That's all folks !" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }