{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Variational AutoEncoder (VAE) with CelebA\n",
    "=========================================\n",
    "---\n",
    "Formation Introduction au Deep Learning  (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020  \n",
    "\n",
    "## Episode 1 - Train a model\n",
    "\n",
    " - Defining a VAE model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process with Tensorboard\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Setup environment\n",
    "### 1.1 - Python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       "\n",
       "div.warn {    \n",
       "    background-color: #fcf2f2;\n",
       "    border-color: #dFb5b4;\n",
       "    border-left: 5px solid #dfb5b4;\n",
       "    padding: 0.5em;\n",
       "    font-weight: bold;\n",
       "    font-size: 1.1em;;\n",
       "    }\n",
       "\n",
       "\n",
       "\n",
       "div.nota {    \n",
       "    background-color: #DAFFDE;\n",
       "    border-left: 5px solid #92CC99;\n",
       "    padding: 0.5em;\n",
       "    }\n",
       "\n",
       "\n",
       "\n",
       "</style>\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "FIDLE 2020 - Practical Work Module\n",
      "Version              : 0.2.8\n",
      "Run time             : Friday 14 February 2020, 00:07:28\n",
      "TensorFlow version   : 2.0.0\n",
      "Keras version        : 2.2.4-tf\n",
      "\n",
      "FIDLE 2020 - Variational AutoEncoder (VAE)\n",
      "TensorFlow version   : 2.0.0\n",
      "VAE version          : 1.28\n",
      "\n",
      "FIDLE 2020 - DataGenerator\n",
      "Version              : 0.4.1\n"
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os,sys\n",
    "from importlib import reload\n",
    "\n",
    "import modules.vae\n",
    "import modules.data_generator\n",
    "\n",
    "reload(modules.data_generator)\n",
    "reload(modules.vae)\n",
    "\n",
    "from modules.vae  import VariationalAutoencoder\n",
    "from modules.data_generator import DataGenerator\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as ooo\n",
    "reload(ooo)\n",
    "\n",
    "ooo.init()\n",
    "\n",
    "VariationalAutoencoder.about()\n",
    "DataGenerator.about()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - The good place"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Well, we should be at IDRIS !\n",
      "We are going to use: /gpfswork/rech/mlh/uja62cb/datasets/celeba\n"
     ]
    }
   ],
   "source": [
    "place, dataset_dir = ooo.good_place( { 'GRICAD' : f'{os.getenv(\"SCRATCH_DIR\",\"\")}/PROJECTS/pr-fidle/datasets/celeba',\n",
    "                                       'IDRIS'  : f'{os.getenv(\"WORK\",\"\")}/datasets/celeba',\n",
    "                                       'HOME'   : f'{os.getenv(\"HOME\",\"\")}/datasets/celeba'} )\n",
    "\n",
    "# ---- train/test datasets\n",
    "\n",
    "train_dir    = f'{dataset_dir}/clusters-M.train'\n",
    "test_dir     = f'{dataset_dir}/clusters-M.test'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - DataGenerator and validation data\n",
    "Ok, everything's perfect, now let's instantiate our generator for the entire dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data generator : 6250 batchs of 32 images, or 200000 images\n",
      "x_test : 2599 images\n"
     ]
    }
   ],
   "source": [
    "data_gen = DataGenerator(train_dir, 32, k_size=1)\n",
    "x_test   = np.load(f'{test_dir}/images-000.npy')\n",
    "\n",
    "print(f'Data generator : {len(data_gen)} batchs of {data_gen.batch_size} images, or {data_gen.dataset_size} images')\n",
    "print(f'x_test : {len(x_test)} images')"
   ]
  },
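  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `DataGenerator` class comes from `modules/data_generator.py`. As a rough, hypothetical sketch (not the FIDLE implementation), a Keras `Sequence`-style generator over the clustered `.npy` files could look like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of a Sequence-style generator over clustered .npy files.\n",
    "# The real DataGenerator lives in modules/data_generator.py and may differ.\n",
    "import glob\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "class NpyClusterGenerator(tf.keras.utils.Sequence):\n",
    "    def __init__(self, data_dir, batch_size=32):\n",
    "        self.files        = sorted(glob.glob(f'{data_dir}/images-*.npy'))\n",
    "        self.batch_size   = batch_size\n",
    "        # Assume every cluster file holds the same number of images,\n",
    "        # and that this number is a multiple of the batch size\n",
    "        self.cluster_size = len(np.load(self.files[0], mmap_mode='r'))\n",
    "        self.dataset_size = self.cluster_size * len(self.files)\n",
    "\n",
    "    def __len__(self):\n",
    "        # Number of batches per epoch\n",
    "        return self.dataset_size // self.batch_size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Map the batch index to a cluster file and a slice inside it\n",
    "        per_cluster = self.cluster_size // self.batch_size\n",
    "        cluster     = np.load(self.files[idx // per_cluster], mmap_mode='r')\n",
    "        start       = (idx % per_cluster) * self.batch_size\n",
    "        x           = np.array(cluster[start:start + self.batch_size])\n",
    "        return x, x   # a VAE learns to reconstruct its input"
   ]
  },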
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Get VAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model initialized.\n",
      "Outputs will be in  : ./run/CelebA.052-M.810156\n",
      "Config saved in     : ./run/CelebA.052-M.810156/models/vae_config.json\n"
   "source": [
    "tag = f'CelebA.052-M.{os.getenv(\"SLURM_JOB_ID\",\"unknown\")}'\n",
    "input_shape = (192, 160, 3)\n",
    "z_dim       = 200\n",
    "verbose     = 0\n",
    "\n",
    "encoder= [ {'type':'Conv2D',          'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "         ]\n",
    "\n",
    "decoder= [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "           {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "           {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Dropout',         'rate':0.25},\n",
    "           {'type':'Conv2DTranspose', 'filters':3,  'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'sigmoid'}\n",
    "         ]\n",
    "\n",
    "vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape, \n",
    "                                         encoder_layers = encoder, \n",
    "                                         decoder_layers = decoder,\n",
    "                                         z_dim          = z_dim, \n",
    "                                         verbose        = verbose,\n",
    "                                         run_tag        = tag)\n",
    "vae.save(model=None)"
   ]
  },
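  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside `VariationalAutoencoder`, the encoder output is mapped to the latent space through a sampling step (the reparameterization trick). Below is a minimal, hypothetical sketch of such a sampling layer, assuming the encoder produces `z_mean` and `z_log_var` tensors of size `z_dim`; the actual layer is defined in `modules/vae.py` and may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the reparameterization trick used by VAE encoders.\n",
    "# The actual sampling layer is defined in modules/vae.py and may differ.\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "class Sampling(layers.Layer):\n",
    "    def call(self, inputs):\n",
    "        z_mean, z_log_var = inputs\n",
    "        # Draw epsilon ~ N(0, I) and shift/scale it with the encoder outputs,\n",
    "        # so gradients can flow through z_mean and z_log_var\n",
    "        epsilon = tf.random.normal(shape=tf.shape(z_mean))\n",
    "        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
    "\n",
    "# Example: sample a batch of latent vectors of size z_dim\n",
    "z = Sampling()([tf.zeros((32, z_dim)), tf.zeros((32, z_dim))])\n",
    "print(z.shape)"
   ]
  },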
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Compile it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiled.\n"
     ]
    }
   ],
   "source": [
    "optimizer     = tf.keras.optimizers.Adam(1e-4)\n",
    "r_loss_factor = 10000\n",
    "\n",
    "vae.compile(optimizer, r_loss_factor)"
   ]
  },
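  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two loss terms reported during training (`vae_r_loss` and `vae_kl_loss`) combine a weighted reconstruction error with a KL divergence. Here is a hedged sketch of such a composite loss, assuming a mean-squared reconstruction error scaled by `r_loss_factor`; the exact formulation used by `modules/vae.py` may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of a typical VAE loss with a weighted reconstruction term.\n",
    "# The exact formulation used by modules/vae.py may differ.\n",
    "\n",
    "def r_loss(y_true, y_pred, r_loss_factor=10000):\n",
    "    # Mean squared reconstruction error over pixels, scaled up so that it\n",
    "    # dominates the KL term at the start of training\n",
    "    return r_loss_factor * tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])\n",
    "\n",
    "def kl_loss(z_mean, z_log_var):\n",
    "    # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I)\n",
    "    return -0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)"
   ]
  },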
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Train\n",
    "For 10 epochs, adam optimizer :  \n",
    "- Run time at IDRIS : 1299.77 sec. - 0:21:39\n",
    "- Run time at GRICAD : 2092.77 sec. - 0:34:52\n",
    "- At IDRIS with medium resolution : Train duration : 6638.61 sec. - 1:50:38"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs            = 20\n",
    "initial_epoch     = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "6250/6250 [==============================] - 342s 55ms/step - loss: 332.4792 - vae_r_loss: 282.9655 - vae_kl_loss: 49.5134 - val_loss: 235.4288 - val_vae_r_loss: 187.4334 - val_vae_kl_loss: 48.1192\n",
      "Epoch 2/20\n",
      "6250/6250 [==============================] - 337s 54ms/step - loss: 224.0962 - vae_r_loss: 166.4602 - vae_kl_loss: 57.6360 - val_loss: 210.9632 - val_vae_r_loss: 155.0925 - val_vae_kl_loss: 56.0114\n",
      "Epoch 4/20\n",
      "6250/6250 [==============================] - 333s 53ms/step - loss: 214.5463 - vae_r_loss: 155.7666 - vae_kl_loss: 58.7794 - val_loss: 203.8241 - val_vae_r_loss: 147.3248 - val_vae_kl_loss: 56.5778\n",
      "Epoch 7/20\n",
      "6250/6250 [==============================] - 327s 52ms/step - loss: 211.1459 - vae_r_loss: 152.4026 - vae_kl_loss: 58.7437 - val_loss: 201.1862 - val_vae_r_loss: 145.6906 - val_vae_kl_loss: 55.6326\n",
      "Epoch 10/20\n",
      "6250/6250 [==============================] - 335s 54ms/step - loss: 209.7628 - vae_r_loss: 151.0874 - vae_kl_loss: 58.6756 - val_loss: 202.3954 - val_vae_r_loss: 147.0956 - val_vae_kl_loss: 55.4047\n",
      "Epoch 12/20\n",
      "6250/6250 [==============================] - 333s 53ms/step - loss: 207.9830 - vae_r_loss: 149.3870 - vae_kl_loss: 58.5959 - val_loss: 198.5626 - val_vae_r_loss: 142.5848 - val_vae_kl_loss: 56.0871\n",
      "Epoch 16/20\n",
      "6250/6250 [==============================] - 330s 53ms/step - loss: 206.6382 - vae_r_loss: 148.0522 - vae_kl_loss: 58.5863 - val_loss: 197.5800 - val_vae_r_loss: 142.6799 - val_vae_kl_loss: 54.9832\n",
      "Train duration : 6638.61 sec. - 1:50:38\n"
   "source": [
    "vae.train(data_generator    = data_gen,\n",
    "          x_test            = x_test,\n",
    "          epochs            = epochs,\n",
    "          initial_epoch     = initial_epoch\n",
    "         )"
   ]
  },
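  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To follow the learning process with TensorBoard, point it at the run directory. Assuming the training logs are written under `./run/<tag>` (the exact layout depends on `modules/vae.py`), one way from the notebook is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch TensorBoard from the notebook (TensorFlow 2.x notebook extension).\n",
    "# The log directory below is an assumption; adjust it to wherever\n",
    "# vae.train() actually writes its TensorBoard event files.\n",
    "%load_ext tensorboard\n",
    "%tensorboard --logdir ./run"
   ]
  },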
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "That's all folks !"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}