{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [VAE8] - Training session for our VAE\n",
    "<!-- DESC --> Episode 4 : Training with our clustered datasets in notebook or batch mode\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Build and train a VAE model with a large dataset in **small or medium resolution (70 to 140 GB)**\n",
    " - Understanding a more advanced programming model with **data generator**\n",
    "\n",
    "The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 images (202599,218,178,3).  \n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Defining a VAE model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process with Tensorboard\n",
    "\n",
    "## Acknowledgements :\n",
    "As before, thanks to **François Chollet** who is at the base of this example.  \n",
    "See : https://keras.io/examples/generative/vae\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.callbacks import TensorBoard\n",
    "\n",
    "from modules.models    import VAE\n",
    "from modules.layers    import SamplingLayer\n",
    "from modules.callbacks import ImagesCallback, BestModelCallback\n",
    "from modules.datagen   import DataGenerator\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as pwk\n",
    "\n",
    "run_dir = './run/VAE8.001'\n",
    "datasets_dir = pwk.init('VAE8', run_dir)\n",
    "\n",
    "VAE.about()\n",
    "DataGenerator.about()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To clean run_dir, uncomment and run this next line\n",
    "# ! rm -r \"$run_dir\"/images-* \"$run_dir\"/logs \"$run_dir\"/figs \"$run_dir\"/models ; rmdir \"$run_dir\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Parameters\n",
    "`scale` : With scale=1, we need 1'30s on a GPU V100 ...and >20' on a CPU !  \n",
    "`latent_dim` : 2 dimensions is small, but usefull to draw !  \n",
    "`fit_verbosity` : verbosity during training : 0 = silent, 1 = progress bar, 2 = one line per epoch  \n",
    "`loss_weights` : Our **loss function** is the weighted sum of two loss:\n",
    " - `r_loss` which measures the loss during reconstruction.  \n",
    " - `kl_loss` which measures the dispersion.  \n",
    "\n",
    "The weights are defined by: `loss_weights=[k1,k2]` where : `total_loss = k1*r_loss + k2*kl_loss`  \n",
    "In practice, a value of \\[.6,.4\\] gives good results here.\n",
    "\n",
    "\n",
    "Uncomment the right lines according to what you want."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_verbosity = 1\n",
    "\n",
    "# ---- For tests\n",
    "#\n",
    "scale         = 0.1\n",
    "image_size    = (128,128)\n",
    "enhanced_dir  = './data'\n",
    "latent_dim    = 300\n",
    "loss_weights  = [.6,.4]\n",
    "batch_size    = 64\n",
    "epochs        = 15\n",
    "\n",
    "# ---- Training with a full dataset\n",
    "#\n",
    "# scale         = 1.\n",
    "# image_size    = (128,128)\n",
    "# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'\n",
    "# latent_dim    = 300\n",
    "# loss_weights  = [.6,.4]\n",
    "# batch_size    = 64\n",
    "# epochs        = 15\n",
    "\n",
    "# ---- Training with a full dataset of large images\n",
    "#\n",
    "# # scale         = 1.\n",
    "# image_size    = (192,160)\n",
    "# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'\n",
    "# latent_dim    = 300\n",
    "# loss_weights  = [.6,.4]\n",
    "# batch_size    = 64\n",
    "# epochs        = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.override('scale', 'image_size', 'enhanced_dir', 'latent_dim', 'loss_weights')\n",
    "pwk.override('batch_size', 'epochs', 'fit_verbosity')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Prepare data\n",
    "Let's instantiate our generator for the entire dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 - Finding the right place"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lx,ly      = image_size\n",
    "train_dir  = f'{enhanced_dir}/clusters-{lx}x{ly}'\n",
    "\n",
    "print('Train directory is :',train_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 - Get a DataGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = DataGenerator(train_dir, 32, scale=scale)\n",
    "\n",
    "print(f'Data generator is ready with : {len(data_gen)} batchs of {data_gen.batch_size} images, or {data_gen.dataset_size} images')"
   ]
  },
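  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `DataGenerator` delivers batches on the fly, so the whole dataset never has to fit in memory.  \n",
    "As an illustration only (the real implementation lives in `modules/datagen`), such a generator is typically a `keras.utils.Sequence` along these lines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only - see modules/datagen for the DataGenerator actually used\n",
    "import math\n",
    "from tensorflow import keras\n",
    "\n",
    "class MiniDataGenerator(keras.utils.Sequence):\n",
    "    def __init__(self, data, batch_size=32):\n",
    "        self.data, self.batch_size = data, batch_size\n",
    "    def __len__(self):                    # number of batches per epoch\n",
    "        return math.ceil(len(self.data) / self.batch_size)\n",
    "    def __getitem__(self, i):             # return batch number i\n",
    "        batch = self.data[i*self.batch_size : (i+1)*self.batch_size]\n",
    "        return batch, batch               # a VAE learns to reconstruct its input"
   ]
  },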
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Build model\n",
    "Note: We conserve the geometry of our last convolutional output (shape_before_flattening) so that we can adapt the decoder to the encoder."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs    = keras.Input(shape=(lx, ly, 3))\n",
    "x         = layers.Conv2D(32, 3, strides=2, padding=\"same\", activation=\"relu\")(inputs)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "\n",
    "shape_before_flattening = keras.backend.int_shape(x)[1:]\n",
    "\n",
    "x         = layers.Flatten()(x)\n",
    "x         = layers.Dense(512, activation=\"relu\")(x)\n",
    "\n",
    "z_mean    = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
    "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
    "z         = SamplingLayer()([z_mean, z_log_var])\n",
    "\n",
    "encoder = keras.Model(inputs, [z_mean, z_log_var, z], name=\"encoder\")\n",
    "encoder.compile()\n",
    "# encoder.summary()"
   ]
  },
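  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SamplingLayer` comes from `modules/layers`.  \n",
    "As a sketch (the actual implementation may differ), such a layer typically applies the reparameterization trick, so that sampling z stays differentiable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of a sampling layer - see modules/layers for the SamplingLayer actually used\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "class MiniSampling(layers.Layer):\n",
    "    # Reparameterization trick : z = z_mean + exp(z_log_var/2) * epsilon, with epsilon ~ N(0,1)\n",
    "    def call(self, inputs):\n",
    "        z_mean, z_log_var = inputs\n",
    "        epsilon = tf.random.normal(shape=tf.shape(z_mean))\n",
    "        return z_mean + tf.exp(0.5 * z_log_var) * epsilon"
   ]
  },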
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs  = keras.Input(shape=(latent_dim,))\n",
    "\n",
    "x = layers.Dense(np.prod(shape_before_flattening))(inputs)\n",
    "x = layers.Reshape(shape_before_flattening)(x)\n",
    "\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(32, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "outputs = layers.Conv2DTranspose(3,  3, padding=\"same\", activation=\"sigmoid\")(x)\n",
    "\n",
    "decoder = keras.Model(inputs, outputs, name=\"decoder\")\n",
    "decoder.compile()\n",
    "# decoder.summary()"
   ]
  },
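  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check : the decoder must map a latent vector back to an image of the original size.  \n",
    "For example, with `image_size=(128,128)`, the four stride-2 transposed convolutions upsample the (8,8,64) feature map back to 128x128x3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The decoder output shape must match the encoder input shape : (None, lx, ly, 3)\n",
    "print('Encoder input shape  :', encoder.input_shape)\n",
    "print('Decoder output shape :', decoder.output_shape)"
   ]
  },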
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VAE\n",
    "Our loss function is the weighted sum of two values.  \n",
    "`reconstruction_loss` which measures the loss during reconstruction.  \n",
    "`kl_loss` which measures the dispersion.  \n",
    "\n",
    "The weights are defined by: `r_loss_factor` :  \n",
    "`total_loss = r_loss_factor*reconstruction_loss + (1-r_loss_factor)*kl_loss`\n",
    "\n",
    "if `r_loss_factor = 1`, the loss function includes only `reconstruction_loss`  \n",
    "if `r_loss_factor = 0`, the loss function includes only `kl_loss`  \n",
    "In practice, a value arround 0.5 gives good results here.\n"
   ]
  },
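  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch (assuming a mean squared error reconstruction term and the usual Gaussian KL term - the actual computation is done inside `modules.models.VAE`), the weighted loss could be written as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the weighted loss - see modules/models for the VAE actually used\n",
    "import tensorflow as tf\n",
    "\n",
    "def weighted_vae_loss(x, x_hat, z_mean, z_log_var, loss_weights=(0.6, 0.4)):\n",
    "    k1, k2  = loss_weights\n",
    "    r_loss  = tf.reduce_mean(tf.square(x - x_hat))   # reconstruction error (MSE here)\n",
    "    kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
    "    return k1 * r_loss + k2 * kl_loss"
   ]
  },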
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = VAE(encoder, decoder, loss_weights)\n",
    "\n",
    "vae.compile(optimizer=keras.optimizers.Adam())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Train\n",
    "With `scale=1`, need 20' for 10 epochs on a V100 (IDRIS)  \n",
    "...on a basic CPU, may be >40 hours !"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 - Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_draw,_   = data_gen[0]\n",
    "data_gen.rewind()\n",
    "\n",
    "callback_images      = ImagesCallback(x=x_draw, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)\n",
    "callback_bestmodel   = BestModelCallback( run_dir + '/models/best_model.h5' )\n",
    "callback_tensorboard = TensorBoard(log_dir=run_dir + '/logs', histogram_freq=1)\n",
    "\n",
    "callbacks_list = [callback_images, callback_bestmodel]"
   ]
  },
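  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To follow the training with TensorBoard, point it at the log directory defined above (assuming TensorBoard and its notebook extension are installed):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To follow training from the notebook, uncomment and run these lines\n",
    "# %load_ext tensorboard\n",
    "# %tensorboard --logdir ./run/VAE8.001/logs"
   ]
  },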
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 - Train it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.chrono_start()\n",
    "\n",
    "history = vae.fit(data_gen, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)\n",
    "\n",
    "pwk.chrono_show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - Training review\n",
    "### 6.1 - History"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.plot_history(history,  plot={\"Loss\":['loss','r_loss', 'kl_loss']}, save_as='01-history')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 - Reconstruction during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_z, images_r = callback_images.get_images( range(0,epochs,2) )\n",
    "\n",
    "pwk.subtitle('Original images :')\n",
    "pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')\n",
    "\n",
    "pwk.subtitle('Encoded/decoded images')\n",
    "pwk.plot_images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')\n",
    "\n",
    "pwk.subtitle('Original images :')\n",
    "pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.3 - Generation (latent -> decoder) during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.subtitle('Generated images from latent space')\n",
    "pwk.plot_images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "8e38643e33497db9a306e3f311fa98cb1e65371278ca73ee4ea0c76aa5a4f387"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('fidle-cpu': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}