{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [VAE8] - Variational AutoEncoder (VAE) with CelebA (small)\n",
    "<!-- DESC --> Variational AutoEncoder (VAE) with CelebA (small res. 128x128)\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Build and train a VAE model with a large dataset in **small resolution (>70 GB)**\n",
    " - Understand a more advanced programming model based on a **data generator**\n",
    "\n",
    "The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 color images of 218x178 pixels, i.e. an array of shape (202599, 218, 178, 3).  \n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Define a VAE model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process with TensorBoard\n",
    "\n",
    "## Acknowledgements :\n",
    "As before, thanks to **François Chollet**, whose example is the basis of this notebook.  \n",
    "See : https://keras.io/examples/generative/vae\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       "\n",
       "div.warn {    \n",
       "    background-color: #fcf2f2;\n",
       "    border-color: #dFb5b4;\n",
       "    border-left: 5px solid #dfb5b4;\n",
       "    padding: 0.5em;\n",
       "    font-weight: bold;\n",
       "    font-size: 1.1em;;\n",
       "    }\n",
       "\n",
       "\n",
       "\n",
       "div.nota {    \n",
       "    background-color: #DAFFDE;\n",
       "    border-left: 5px solid #92CC99;\n",
       "    padding: 0.5em;\n",
       "    }\n",
       "\n",
       "div.todo:before { content:url(data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSI1My44OTEyIiBoZWlnaHQ9IjE0My4zOTAyIiB2aWV3Qm94PSIwIDAgNTMuODkxMiAxNDMuMzkwMiI+PHRpdGxlPjAwLUJvYi10b2RvPC90aXRsZT48cGF0aCBkPSJNMjMuNDU2OCwxMTQuMzAxNmExLjgwNjMsMS44MDYzLDAsMSwxLDEuODE1NywxLjgyNEExLjgyMDksMS44MjA5LDAsMCwxLDIzLjQ1NjgsMTE0LjMwMTZabS0xMC42NjEyLDEuODIyQTEuODI3MiwxLjgyNzIsMCwxLDAsMTAuOTgsMTE0LjMsMS44MiwxLjgyLDAsMCwwLDEyLjc5NTYsMTE2LjEyMzZabS03LjcwNyw0LjU4NzR2LTVzLjQ4NjMtOS4xMjIzLDguMDIxNS0xMS45Njc1YTE5LjIwODIsMTkuMjA4MiwwLDAsMSw2LjA0ODYtMS4yNDU0LDE5LjE3NzgsMTkuMTc3OCwwLDAsMSw2LjA0ODcsMS4yNDc1YzcuNTM1MSwyLjgzNDcsOC4wMTc0LDExLjk2NzQsOC4wMTc0LDExLjk2NzR2NS4wMjM0bC4wMDQyLDcuNjgydjIuNGMuMDE2Ny4xOTkyLjAzMzYuMzkyMS4wMzM2LjU4NzEsMCwuMjEzOC0uMDE2OC40MTA5LS4wMzM2LjYzMzJ2LjA1ODdoLS4wMDg0YTguMzcxOSw4LjM3MTksMCwwLDEtNy4zNzM4LDcuNjU0N3MtLjk5NTMsMy42MzgtNi42OTMzLDMuNjM4LTYuNjkzNC0zLjYzOC02LjY5MzQtMy42MzhhOC4zNyw4LjM3LDAsMCwxLTcuMzcxNi03LjY1NDdINS4wODQzdi0uMDU4N2MtLjAxODktLjIyLS4wMjk0LS40MTk0LS4wMjk0LS42MzMyLDAtLjE5MjkuMDE2Ny0uMzgzNy4wMjk0LS41ODcxdi0yLjRtMTguMDkzNy00LjA0YTEuMTU2NSwxLjE1NjUsMCwxLDAtMi4zMTI2LDAsMS4xNTY0LDEuMTU2NCwwLDEsMCwyLjMxMjYsMFptNC4wODM0LDBhMS4xNTk1LDEuMTU5NSwwLDEsMC0xLjE2MzYsMS4xN0ExLjE3NSwxLjE3NSwwLDAsMCwyNy4yNjE0LDEyNC4zNzc5Wk05LjM3MzksMTE0LjYzNWMwLDMuMTA5MywyLjQxMzIsMy4zMSwyLjQxMzIsMy4zMWExMzMuOTI0MywxMzMuOTI0MywwLDAsMCwxNC43MzQ4LDBzMi40MTExLS4xOTI5LDIuNDExMS0zLjMxYTguMDc3Myw4LjA3NzMsMCwwLDAtMi40MTExLTUuNTUxOWMtNC41LTMuNTAzMy05LjkxMjYtMy41MDMzLTE0Ljc0MTEsMEE4LjA4NTEsOC4wODUxLDAsMCwwLDkuMzczOSwxMTQuNjM1WiIgc3R5bGU9ImZpbGw6IzAxMDEwMSIvPjxjaXJjbGUgY3g9IjMzLjE0MzYiIGN5PSIxMjQuNTM0IiByPSIzLjgzNjMiIHN0eWxlPSJmaWxsOiMwMTAxMDEiLz48cmVjdCB4PSIzNS42NjU5IiB5PSIxMTIuOTYyNSIgd2lkdGg9IjIuMDc3IiBoZWlnaHQ9IjEwLjU0NTgiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDIxLjYgMjQxLjExMjEpIHJvdGF0ZSgtMTU1Ljc0NikiIHN0eWxlPSJmaWxsOiMwMTAxMDEiLz48Y2lyY2xlIGN4PSIzOC44NzA0IiBjeT0iMTEzLjQyNzkiIHI9IjIuNDA4NSIgc3R5bGU9ImZpbGw6IzAxMDEwMSIvPjxjaXJjbGUgY3g9IjUuMjI0OCIgY3k9IjEy
NC41MzQiIHI9IjMuODM2MyIgc3R5bGU9ImZpbGw6IzAxMDEwMSIvPjxyZWN0IHg9IjEuNDE2NCIgeT0iMTI0LjYzMDEiIHdpZHRoPSIyLjA3NyIgaGVpZ2h0PSIxMC41NDU4IiB0cmFuc2Zvcm09InRyYW5zbGF0ZSg0LjkwOTcgMjU5LjgwNikgcm90YXRlKC0xODApIiBzdHlsZT0iZmlsbDojMDEwMTAxIi8+PGNpcmNsZSBjeD0iMi40MDkxIiBjeT0iMTM3LjA5OTYiIHI9IjIuNDA4NSIgc3R5bGU9ImZpbGw6IzAxMDEwMSIvPjxwYXRoIGQ9Ik0xOC4wNTExLDEwMC4xMDY2aC0uMDE0NlYxMDIuNjFoMi4zdi0yLjQyNzlhMi40MjI5LDIuNDIyOSwwLDEsMC0yLjI4NTQtLjA3NTVaIiBzdHlsZT0iZmlsbDojMDEwMTAxIi8+PHBhdGggZD0iTTM5LjQyMTQsMjcuMjU4djEuMDVBMTEuOTQ1MiwxMS45NDUyLDAsMCwwLDQ0LjU5NTQsNS43OWEuMjQ0OS4yNDQ5LDAsMCwxLS4wMjM1LS40MjI3TDQ2Ljc1LDMuOTUxNWEuMzg5Mi4zODkyLDAsMCwxLC40MjYyLDAsMTQuODQ0MiwxNC44NDQyLDAsMCwxLTcuNzU0MywyNy4yNTkxdjEuMDY3YS40NS40NSwwLDAsMS0uNzA0Ny4zNzU4bC0zLjg0MTktMi41MWEuNDUuNDUsMCwwLDEsMC0uNzUxNmwzLjg0MTktMi41MWEuNDUuNDUsMCwwLDEsLjY5NDYuMzc1OFpNNDMuMjMsMi41ODkyLDM5LjM4NzguMDc5NGEuNDUuNDUsMCwwLDAtLjcwNDYuMzc1OHYxLjA2N2ExNC44NDQyLDE0Ljg0NDIsMCwwLDAtNy43NTQzLDI3LjI1OTEuMzg5LjM4OSwwLDAsMCwuNDI2MSwwbDIuMTc3Ny0xLjQxOTNhLjI0NS4yNDUsMCwwLDAtLjAyMzUtLjQyMjgsMTEuOTQ1MSwxMS45NDUxLDAsMCwxLDUuMTc0LTIyLjUxNDZ2MS4wNWEuNDUuNDUsMCwwLDAsLjcwNDYuMzc1OGwzLjg1NTMtMi41MWEuNDUuNDUsMCwwLDAsMC0uNzUxNlpNMzkuMDUyMywxNC4yNDU4YTIuMTIwNiwyLjEyMDYsMCwxLDAsMi4xMjA2LDIuMTIwNmgwQTIuMTI0LDIuMTI0LDAsMCwwLDM5LjA1MjMsMTQuMjQ1OFptNi4wNzMyLTQuNzc4MS44MjU0LjgyNTVhMS4wNTY4LDEuMDU2OCwwLDAsMSwuMTE3NSwxLjM0MjFsLS44MDIsMS4xNDQyYTcuMTAxOCw3LjEwMTgsMCwwLDEsLjcxMTQsMS43MTEybDEuMzc1Ny4yNDE2YTEuMDU2OSwxLjA1NjksMCwwLDEsLjg3NTcsMS4wNHYxLjE2NDNhMS4wNTY5LDEuMDU2OSwwLDAsMS0uODc1NywxLjA0bC0xLjM3MjQuMjQxNkE3LjExLDcuMTEsMCwwLDEsNDUuMjcsMTkuOTNsLjgwMTksMS4xNDQyYTEuMDU3LDEuMDU3LDAsMCwxLS4xMTc0LDEuMzQyMmwtLjgyODguODQ4OWExLjA1NywxLjA1NywwLDAsMS0xLjM0MjEuMTE3NGwtMS4xNDQyLS44MDE5YTcuMTMzOCw3LjEzMzgsMCwwLDEtMS43MTEzLjcxMTNsLS4yNDE2LDEuMzcyNGExLjA1NjgsMS4wNTY4LDAsMCwxLTEuMDQuODc1N0gzOC40Njg0YTEuMDU2OCwxLjA1NjgsMCwwLDEtMS4wNC0uODc1N2wtLjI0MTYtMS4zNzI0YTcuMTM1NSw3LjEzNTUsMCwwLDEtMS43MTEzLS43MTEzbC0xLjE0NDEuODAxOWExLjA1NzEsMS4wNTcxLDAsMCwxLTEuMzQyMi0uMTE3NGwtLjgz
NTUtLjgyNTVhMS4wNTcsMS4wNTcsMCwwLDEtLjExNzQtMS4zNDIxbC44MDE5LTEuMTQ0MmE3LjEyMSw3LjEyMSwwLDAsMS0uNzExMy0xLjcxMTJsLTEuMzcyNC0uMjQxNmExLjA1NjksMS4wNTY5LDAsMCwxLS44NzU3LTEuMDRWMTUuNzgyNmExLjA1NjksMS4wNTY5LDAsMCwxLC44NzU3LTEuMDRsMS4zNzU3LS4yNDE2YTcuMTEsNy4xMSwwLDAsMSwuNzExNC0xLjcxMTJsLS44MDItMS4xNDQyYTEuMDU3LDEuMDU3LDAsMCwxLC4xMTc1LTEuMzQyMmwuODI1NC0uODI1NEExLjA1NjgsMS4wNTY4LDAsMCwxLDM0LjMyNDUsOS4zNmwxLjE0NDIuODAxOUE3LjEzNTUsNy4xMzU1LDAsMCwxLDM3LjE4LDkuNDUxbC4yNDE2LTEuMzcyNGExLjA1NjgsMS4wNTY4LDAsMCwxLDEuMDQtLjg3NTdoMS4xNjc3YTEuMDU2OSwxLjA1NjksMCwwLDEsMS4wNC44NzU3bC4yNDE2LDEuMzcyNGE3LjEyNSw3LjEyNSwwLDAsMSwxLjcxMTIuNzExM0w0My43NjY2LDkuMzZBMS4wNTY5LDEuMDU2OSwwLDAsMSw0NS4xMjU1LDkuNDY3N1ptLTIuMDMsNi44OTg3QTQuMDQzMyw0LjA0MzMsMCwxLDAsMzkuMDUyMywyMC40MWgwQTQuMDQ2NSw0LjA0NjUsMCwwLDAsNDMuMDk1NSwxNi4zNjY0WiIgc3R5bGU9ImZpbGw6I2UxMjIyOSIvPjxwb2x5Z29uIHBvaW50cz0iMzkuNDEzIDM0Ljc1NyAzOS41MzcgMzQuNzU3IDM5LjY3NSAzNC43NTcgMzkuNjc1IDEwOS41MSAzOS41MzcgMTA5LjUxIDM5LjQxMyAxMDkuNTEgMzkuNDEzIDM0Ljc1NyAzOS40MTMgMzQuNzU3IiBzdHlsZT0iZmlsbDpub25lO3N0cm9rZTojOTk5O3N0cm9rZS1saW5lY2FwOnJvdW5kO3N0cm9rZS1taXRlcmxpbWl0OjEwO3N0cm9rZS13aWR0aDowLjMwODg1NDQ1MDU2MDE2MThweDtmaWxsLXJ1bGU6ZXZlbm9kZCIvPjwvc3ZnPg==);\n",
       "    float:left;\n",
       "    margin-right:20px;\n",
       "    margin-top:-20px;\n",
       "    margin-bottom:20px;\n",
       "}\n",
       "div.todo{\n",
       "    font-weight: bold;\n",
       "    font-size: 1.1em;\n",
       "    margin-top:40px;\n",
       "}\n",
       "div.todo ul{\n",
       "    margin: 0.2em;\n",
       "}\n",
       "div.todo li{\n",
       "    margin-left:60px;\n",
       "    margin-top:0;\n",
       "    margin-bottom:0;\n",
       "}\n",
       "\n",
       "div .comment{\n",
       "    font-size:0.8em;\n",
       "    color:#696969;\n",
       "}\n",
       "\n",
       "\n",
       "\n",
       "</style>\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Override : Attribute [run_dir=./run/CelebA.001] with [./run/test-VAE8-3370]\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "**FIDLE 2020 - Practical Work Module**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Version              : 0.6.1 DEV\n",
      "Notebook id          : VAE8\n",
      "Run time             : Wednesday 6 January 2021, 19:47:34\n",
      "TensorFlow version   : 2.2.0\n",
      "Keras version        : 2.3.0-tf\n",
      "Datasets dir         : /home/pjluc/datasets/fidle\n",
      "Run dir              : ./run/test-VAE8-3370\n",
      "Update keras cache   : False\n",
      "Save figs            : True\n",
      "Path figs            : ./run/test-VAE8-3370/figs\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "<br>**FIDLE 2021 - VAE**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Version              : 1.2\n",
      "TensorFlow version   : 2.2.0\n",
      "Keras version        : 2.3.0-tf\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "<br>**FIDLE 2020 - DataGenerator**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Version              : 0.4.1\n",
      "TensorFlow version   : 2.2.0\n",
      "Keras version        : 2.3.0-tf\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from skimage import io\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
    "\n",
    "import os,sys,json,time,datetime\n",
    "from IPython.display import display,Image,Markdown,HTML\n",
    "\n",
    "from modules.data_generator import DataGenerator\n",
    "from modules.VAE            import VAE, Sampling\n",
    "from modules.callbacks      import ImagesCallback, BestModelCallback\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as pwk\n",
    "\n",
    "run_dir = './run/CelebA.001'                  # Output directory\n",
    "datasets_dir = pwk.init('VAE8', run_dir)\n",
    "\n",
    "VAE.about()\n",
    "DataGenerator.about()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To clean run_dir, uncomment and run this next line\n",
    "# ! rm -r \"$run_dir\"/images-* \"$run_dir\"/logs \"$run_dir\"/figs \"$run_dir\"/models ; rmdir \"$run_dir\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Get some data\n",
    "Let's instantiate our generator for the entire dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 - Parameters\n",
    "Uncomment the lines matching the data you want to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- For tests\n",
    "scale         = 0.3\n",
    "image_size    = (128,128)\n",
    "enhanced_dir  = './data'\n",
    "latent_dim    = 300\n",
    "r_loss_factor = 0.6\n",
    "batch_size    = 64\n",
    "epochs        = 15\n",
    "\n",
    "# ---- Training with a full dataset\n",
    "# scale         = 1.\n",
    "# image_size    = (128,128)\n",
    "# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'\n",
    "# latent_dim    = 300\n",
    "# r_loss_factor = 0.6\n",
    "# batch_size    = 64\n",
    "# epochs        = 15\n",
    "\n",
    "# ---- Training with a full dataset of large images\n",
    "# scale         = 1.\n",
    "# image_size    = (192,160)\n",
    "# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'\n",
    "# latent_dim    = 300\n",
    "# r_loss_factor = 0.6\n",
    "# batch_size    = 64\n",
    "# epochs        = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 - Finding the right place"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train directory is : ./data/clusters-128x128\n"
     ]
    }
   ],
   "source": [
    "# ---- Override parameters (batch mode) - Just forget this line\n",
    "#\n",
    "pwk.override('scale', 'image_size', 'enhanced_dir', 'latent_dim', 'r_loss_factor', 'batch_size', 'epochs')\n",
    "\n",
    "# ---- the place of the clusters files\n",
    "#\n",
    "lx,ly      = image_size\n",
    "train_dir  = f'{enhanced_dir}/clusters-{lx}x{ly}'\n",
    "print('Train directory is :',train_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 - Get a DataGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data generator is ready with : 379 batchs of 32 images, or 12155 images\n"
     ]
    }
   ],
   "source": [
    "data_gen = DataGenerator(train_dir, batch_size, k_size=scale)\n",
    "\n",
    "print(f'Data generator is ready with : {len(data_gen)} batches of {data_gen.batch_size} images, or {data_gen.dataset_size} images')"
   ]
  },
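  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A data generator hands Keras one batch at a time, so the whole dataset never has to be held in memory. The code of `DataGenerator` (in `modules/data_generator.py`) is not shown in this notebook. Below is a hypothetical, simplified stand-in named `ToyGenerator`. It is a sketch, not the module's implementation. It follows the `keras.utils.Sequence` protocol (`__len__`, `__getitem__`, `on_epoch_end`) over an in-memory NumPy array and returns `(x, x)` pairs, since an autoencoder's target is its own input. The output above suggests that the real generator drops the last incomplete batch (12155 // 32 = 379 batches), and this sketch does the same.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class ToyGenerator:\n",
    "    # Hypothetical stand-in; a Keras-ready version would subclass tf.keras.utils.Sequence\n",
    "    def __init__(self, data, batch_size, shuffle=True, seed=0):\n",
    "        self.data       = data\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle    = shuffle\n",
    "        self.rng        = np.random.default_rng(seed)\n",
    "        self.indexes    = np.arange(len(data))\n",
    "        self.on_epoch_end()\n",
    "\n",
    "    def __len__(self):\n",
    "        # Number of complete batches per epoch (the remainder is dropped)\n",
    "        return len(self.data) // self.batch_size\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        # Batch i: inputs and targets are the same images\n",
    "        idx = self.indexes[i * self.batch_size:(i + 1) * self.batch_size]\n",
    "        x   = self.data[idx]\n",
    "        return x, x\n",
    "\n",
    "    def on_epoch_end(self):\n",
    "        # Shuffle the sample order between epochs\n",
    "        if self.shuffle:\n",
    "            self.rng.shuffle(self.indexes)\n",
    "\n",
    "\n",
    "images = np.random.rand(100, 8, 8, 3).astype('float32')\n",
    "gen    = ToyGenerator(images, batch_size=32)\n",
    "x, y   = gen[0]\n",
    "print(len(gen), x.shape)   # 3 (32, 8, 8, 3)\n",
    "```"
   ]
  },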
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Build model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs    = keras.Input(shape=(lx, ly, 3))\n",
    "x         = layers.Conv2D(32, 3, strides=2, padding=\"same\", activation=\"relu\")(inputs)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "\n",
    "shape_before_flattening = keras.backend.int_shape(x)[1:]\n",
    "\n",
    "x         = layers.Flatten()(x)\n",
    "x         = layers.Dense(512, activation=\"relu\")(x)\n",
    "\n",
    "z_mean    = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
    "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
    "z         = Sampling()([z_mean, z_log_var])\n",
    "\n",
    "encoder = keras.Model(inputs, [z_mean, z_log_var, z], name=\"encoder\")\n",
    "encoder.compile()\n",
    "# encoder.summary()"
   ]
  },
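  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Sampling` layer is imported from `modules/VAE.py`, whose code is not shown here. If it follows the Keras example credited above (an assumption), it implements the reparameterization trick: `z = z_mean + exp(0.5 * z_log_var) * epsilon`, with `epsilon` drawn from a standard normal distribution. Because the random draw sits outside the path that gradients follow, `z_mean` and `z_log_var` remain trainable. The NumPy sketch below reproduces that computation (it is not the module's code) and checks that the samples have the requested mean and standard deviation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_z(z_mean, z_log_var, rng):\n",
    "    # epsilon ~ N(0, I), one draw per latent coordinate\n",
    "    epsilon = rng.standard_normal(z_mean.shape)\n",
    "    # exp(0.5 * log(sigma^2)) = sigma\n",
    "    return z_mean + np.exp(0.5 * z_log_var) * epsilon\n",
    "\n",
    "\n",
    "rng       = np.random.default_rng(42)\n",
    "z_mean    = np.full((20000, 3), 2.0)\n",
    "z_log_var = np.full((20000, 3), np.log(0.25))   # variance 0.25, i.e. std 0.5\n",
    "z         = sample_z(z_mean, z_log_var, rng)\n",
    "print(z.mean(axis=0), z.std(axis=0))   # close to 2.0 and 0.5 in each coordinate\n",
    "```"
   ]
  },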
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs  = keras.Input(shape=(latent_dim,))\n",
    "\n",
    "x = layers.Dense(np.prod(shape_before_flattening))(inputs)\n",
    "x = layers.Reshape(shape_before_flattening)(x)\n",
    "\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(32, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "outputs = layers.Conv2DTranspose(3,  3, padding=\"same\", activation=\"sigmoid\")(x)\n",
    "\n",
    "decoder = keras.Model(inputs, outputs, name=\"decoder\")\n",
    "decoder.compile()\n",
    "# decoder.summary()"
   ]
  },
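  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each `Conv2D` layer of the encoder uses `strides=2` and same padding, so it halves the spatial size (rounding up). Each `Conv2DTranspose` layer of the decoder uses the same settings and doubles the size. The plain-Python sketch below applies these two rules to the two `image_size` values proposed in the parameters cell. It only does the arithmetic and builds no Keras layer. The helper names are illustrative.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def encoder_shape(lx, ly, n_conv=4, filters=64):\n",
    "    # Conv2D, stride 2, 'same' padding: output size = ceil(input / 2)\n",
    "    for _ in range(n_conv):\n",
    "        lx, ly = math.ceil(lx / 2), math.ceil(ly / 2)\n",
    "    return (lx, ly, filters)\n",
    "\n",
    "\n",
    "def decoder_size(lx, ly, n_deconv=4):\n",
    "    # Conv2DTranspose, stride 2, 'same' padding: output size = input * 2\n",
    "    for _ in range(n_deconv):\n",
    "        lx, ly = lx * 2, ly * 2\n",
    "    return (lx, ly)\n",
    "\n",
    "\n",
    "for size in [(128, 128), (192, 160)]:\n",
    "    before = encoder_shape(*size)\n",
    "    print(size, '->', before, '->', decoder_size(before[0], before[1]))\n",
    "```\n",
    "\n",
    "For 128x128 inputs this gives `shape_before_flattening = (8, 8, 64)`, i.e. 4096 values after `Flatten`. For 192x160 inputs it gives (12, 10, 64). Both sizes are multiples of 16, which is why the decoder lands exactly on the input size. A height of 218, for example, would shrink to 14 and come back as 224."
   ]
  },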
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VAE\n",
    "Our loss function is the weighted sum of two terms:  \n",
    "\n",
    " - `reconstruction_loss`, which measures how much the decoded image differs from the input image,  \n",
    " - `kl_loss`, the Kullback-Leibler divergence between the latent distribution produced by the encoder and a standard normal distribution, which regularizes the latent space.  \n",
    "\n",
    "The weighting is set by `r_loss_factor`:  \n",
    "`total_loss = r_loss_factor*reconstruction_loss + (1-r_loss_factor)*kl_loss`\n",
    "\n",
    "If `r_loss_factor = 1`, the loss function includes only `reconstruction_loss`.  \n",
    "If `r_loss_factor = 0`, the loss function includes only `kl_loss`.  \n",
    "In practice, a value around 0.5 gives good results here.\n"
   ]
  },
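  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a NumPy sketch of this weighted loss on one batch, using a hypothetical `vae_loss` helper. This notebook does not show how `modules/VAE.py` computes each term, so the sketch makes two assumptions. First, the reconstruction term is the mean binary cross-entropy between input and reconstructed pixels. Second, the KL term is the usual closed form for a diagonal Gaussian against N(0, I), summed over the latent dimensions and averaged over the batch.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def vae_loss(x, x_rec, z_mean, z_log_var, r_loss_factor, eps=1e-7):\n",
    "    # Reconstruction term (assumption): mean binary cross-entropy per pixel\n",
    "    x_rec  = np.clip(x_rec, eps, 1 - eps)\n",
    "    r_loss = np.mean(-(x * np.log(x_rec) + (1 - x) * np.log(1 - x_rec)))\n",
    "    # KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over latent dims, averaged over the batch\n",
    "    kl      = -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=1)\n",
    "    kl_loss = np.mean(kl)\n",
    "    total   = r_loss_factor * r_loss + (1 - r_loss_factor) * kl_loss\n",
    "    return total, r_loss, kl_loss\n",
    "\n",
    "\n",
    "rng       = np.random.default_rng(0)\n",
    "x         = rng.random((4, 8, 8, 3))   # batch of 4 images with pixels in [0, 1)\n",
    "x_rec     = rng.random((4, 8, 8, 3))   # their (random) reconstructions\n",
    "z_mean    = np.zeros((4, 300))\n",
    "z_log_var = np.zeros((4, 300))\n",
    "total, r_loss, kl_loss = vae_loss(x, x_rec, z_mean, z_log_var, r_loss_factor=0.6)\n",
    "print(total, r_loss, kl_loss)\n",
    "```\n",
    "\n",
    "With `z_mean = 0` and `z_log_var = 0`, the encoder's distribution already matches the prior, so `kl_loss` is 0 and `total_loss` reduces to `r_loss_factor * reconstruction_loss`."
   ]
  },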
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = VAE(encoder, decoder, r_loss_factor)\n",
    "\n",
    "vae.compile(optimizer=keras.optimizers.Adam())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Train\n",
    "20' on a CPU  \n",
    "1'12 on a GPU (V100, IDRIS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 - Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_draw,_   = data_gen[0]\n",
    "data_gen.rewind()\n",
    "\n",
    "# ---- Callback : Images encoded\n",
    "pwk.mkdir(run_dir + '/images-encoded')\n",
    "filename = run_dir + '/images-encoded/image-{epoch:03d}-{i:02d}.jpg'\n",
    "callback_images1 = ImagesCallback(filename, x=x_draw[:5], encoder=encoder,decoder=decoder)\n",
    "\n",
    "# ---- Callback : Images generated\n",
    "pwk.mkdir(run_dir + '/images-generated')\n",
    "filename = run_dir + '/images-generated/image-{epoch:03d}-{i:02d}.jpg'\n",
    "callback_images2 = ImagesCallback(filename, x=None, nb_images=5, z_dim=latent_dim, encoder=encoder,decoder=decoder)          \n",
    "\n",
    "# ---- Callback : Best model\n",
    "pwk.mkdir(run_dir + '/models')\n",
    "filename = run_dir + '/models/best_model'\n",
    "callback_bestmodel = BestModelCallback(filename)\n",
    "\n",
    "# ---- Callback tensorboard\n",
    "dirname = run_dir + '/logs'\n",
    "callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)\n",
    "\n",
    "# ---- Callbacks list : TensorBoard is left out here, swap the comment to enable it\n",
    "# callbacks_list = [callback_images1, callback_images2, callback_bestmodel, callback_tensorboard]\n",
    "callbacks_list = [callback_images1, callback_images2, callback_bestmodel]"
   ]
  },
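  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The image callbacks receive a filename *template* rather than a filename. `{epoch:03d}` and `{i:02d}` are `str.format` placeholders, presumably filled in by `ImagesCallback` at the end of each epoch. Its code in `modules/callbacks.py` is not shown, so this is an assumption. The template is built by concatenating `run_dir` with a plain string, not with an f-string, so the braces survive until the callback formats them. The sketch below expands such a template the same way, using the default `run_dir`.\n",
    "\n",
    "```python\n",
    "run_dir  = './run/CelebA.001'\n",
    "template = run_dir + '/images-encoded/image-{epoch:03d}-{i:02d}.jpg'\n",
    "\n",
    "# Expand the template for epochs 1-2 and the first two images of each epoch\n",
    "names = [template.format(epoch=epoch, i=i) for epoch in range(1, 3) for i in range(2)]\n",
    "for name in names:\n",
    "    print(name)\n",
    "```"
   ]
  },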
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 - Train it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.chrono_start()\n",
    "\n",
    "# ---- The batch size is set by the data generator, so it is not passed to fit()\n",
    "history = vae.fit(data_gen, epochs=epochs, callbacks=callbacks_list)\n",
    "\n",
    "pwk.chrono_show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - About our training session\n",
    "### 5.1 - History"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.plot_history(history,  plot={\"Loss\":['loss','r_loss', 'kl_loss']}, save_as='01-history')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 - Reconstruction (input -> encoder -> decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs=[]\n",
    "for epoch in range(1,epochs,1):\n",
    "    for i in range(5):\n",
    "        filename = f'{run_dir}/images-encoded/image-{epoch:03d}-{i:02d}.jpg'\n",
    "        img      = io.imread(filename)\n",
    "        imgs.append(img)\n",
    "        \n",
    "\n",
    "pwk.subtitle('Original images :')\n",
    "pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')\n",
    "\n",
    "pwk.subtitle('Encoded/decoded images')\n",
    "pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')\n",
    "\n",
    "pwk.subtitle('Original images :')\n",
    "pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 - Generation (latent -> decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs=[]\n",
    "for epoch in range(1,epochs,1):\n",
    "    for i in range(5):\n",
    "        filename = f'{run_dir}/images-generated/image-{epoch:03d}-{i:02d}.jpg'\n",
    "        img      = io.imread(filename)\n",
    "        imgs.append(img)\n",
    "        \n",
    "pwk.subtitle('Generated images from latent space')\n",
    "pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-generated')\n"
   ]
  },
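  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generation skips the encoder: points are drawn in the latent space and passed to the decoder. This sketch assumes that `ImagesCallback` with `x=None` draws its `nb_images` latent vectors from a standard normal distribution, the prior the KL term pulls the encoder towards. Under that assumption, the trained `decoder` can do the same thing directly. The `generate` helper and the `fake_decoder` stand-in are hypothetical names, used so that the sketch runs outside the notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def generate(decoder_predict, latent_dim, n_images=5, seed=None):\n",
    "    # Draw n_images latent vectors z ~ N(0, I), then decode them into images\n",
    "    rng = np.random.default_rng(seed)\n",
    "    z   = rng.standard_normal((n_images, latent_dim)).astype('float32')\n",
    "    return decoder_predict(z)\n",
    "\n",
    "\n",
    "def fake_decoder(z):\n",
    "    # Stand-in for decoder.predict: returns mid-gray 128x128 RGB images\n",
    "    return np.full((len(z), 128, 128, 3), 0.5, dtype='float32')\n",
    "\n",
    "\n",
    "imgs = generate(fake_decoder, latent_dim=300, seed=0)\n",
    "print(imgs.shape)   # (5, 128, 128, 3)\n",
    "```\n",
    "\n",
    "In this notebook, after training, `generate(decoder.predict, latent_dim)` would return new images that can be displayed with `pwk.plot_images`, as in the cell above."
   ]
  },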
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}