{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [VAE8] - Training session for our VAE\n",
    "<!-- DESC --> Episode 4 : Training with our clustered datasets in notebook or batch mode\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Build and train a VAE model with a large dataset in **small or medium resolution (70 to 140 GB)**\n",
    " - Understanding a more advanced programming model with **data generator**\n",
    "\n",
    "The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 images (202599,218,178,3).  \n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Defining a VAE model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process with Tensorboard\n",
    "\n",
    "## Acknowledgements :\n",
    "As before, thanks to **François Chollet** who is at the base of this example.  \n",
    "See : https://keras.io/examples/generative/vae\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from skimage import io\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
    "\n",
    "import os,sys,json,time,datetime\n",
    "from IPython.display import display,Image,Markdown,HTML\n",
    "\n",
    "from modules.data_generator import DataGenerator\n",
    "from modules.VAE            import VAE, Sampling\n",
    "from modules.callbacks      import ImagesCallback, BestModelCallback\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as pwk\n",
    "\n",
    "run_dir = './run/CelebA.001'                  # Output directory\n",
    "datasets_dir = pwk.init('VAE8', run_dir)\n",
    "\n",
    "VAE.about()\n",
    "DataGenerator.about()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To clean run_dir, uncomment and run this next line\n",
    "# ! rm -r \"$run_dir\"/images-* \"$run_dir\"/logs \"$run_dir\"/figs \"$run_dir\"/models ; rmdir \"$run_dir\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Parameters\n",
    "Uncomment the right lines according to what you want."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- For tests\n",
    "scale         = 0.3\n",
    "image_size    = (128,128)\n",
    "enhanced_dir  = './data'\n",
    "latent_dim    = 300\n",
    "r_loss_factor = 0.6\n",
    "batch_size    = 64\n",
    "epochs        = 15\n",
    "\n",
    "# ---- Training with a full dataset\n",
    "# scale         = 1.\n",
    "# image_size    = (128,128)\n",
    "# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'\n",
    "# latent_dim    = 300\n",
    "# r_loss_factor = 0.6\n",
    "# batch_size    = 64\n",
    "# epochs        = 15\n",
    "\n",
    "# ---- Training with a full dataset of large images\n",
    "# scale         = 1.\n",
    "# image_size    = (192,160)\n",
    "# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'\n",
    "# latent_dim    = 300\n",
    "# r_loss_factor = 0.6\n",
    "# batch_size    = 64\n",
    "# epochs        = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.override('scale', 'image_size', 'enhanced_dir', 'latent_dim', 'r_loss_factor', 'batch_size', 'epochs')"
   ]
  },
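  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the curious : in batch mode, `pwk.override` lets the parameters above be redefined from outside the notebook. A minimal sketch of the idea, assuming overrides arrive as environment variables - the real `fidle.pwk` implementation may well differ :\n",
    "```python\n",
    "import os\n",
    "\n",
    "def override(*names, prefix='FIDLE_OVERRIDE_'):\n",
    "    # Hypothetical mechanism : for each name, look for an environment\n",
    "    # variable and, if present, override the notebook global.\n",
    "    for name in names:\n",
    "        value = os.environ.get(prefix + name)\n",
    "        if value is not None:\n",
    "            globals()[name] = eval(value)   # e.g. '0.3' -> 0.3, '(128,128)' -> (128, 128)\n",
    "            print(f'Parameter {name} overridden with {value}')\n",
    "```"
   ]
  },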
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Get some data\n",
    "Let's instantiate our generator for the entire dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 - Finding the right place"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lx,ly      = image_size\n",
    "train_dir  = f'{enhanced_dir}/clusters-{lx}x{ly}'\n",
    "\n",
    "print('Train directory is :',train_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 - Get a DataGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = DataGenerator(train_dir, 32, k_size=scale)\n",
    "\n",
    "print(f'Data generator is ready with : {len(data_gen)} batchs of {data_gen.batch_size} images, or {data_gen.dataset_size} images')"
   ]
  },
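  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How can we train on a 70-140 GB dataset that doesn't fit in memory ? Keras can consume data batch by batch from a generator. As a rough idea of how our `DataGenerator` works, here is a minimal sketch based on `keras.utils.Sequence` - the real class in `modules/data_generator.py`, which reads the clustered files on demand and handles the `k_size` scale factor, is more elaborate :\n",
    "```python\n",
    "import math\n",
    "from tensorflow import keras\n",
    "\n",
    "class MinimalDataGenerator(keras.utils.Sequence):\n",
    "    # Serves (x, x) batches from an array : for a VAE, inputs are also targets\n",
    "\n",
    "    def __init__(self, data, batch_size=32):\n",
    "        self.data       = data\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def __len__(self):\n",
    "        # Number of batches per epoch\n",
    "        return math.floor(len(self.data) / self.batch_size)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        # Keras calls this batch by batch, so only one chunk\n",
    "        # of the dataset needs to be in memory at a time\n",
    "        batch = self.data[i*self.batch_size : (i+1)*self.batch_size]\n",
    "        return batch, batch\n",
    "```"
   ]
  },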
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Build model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs    = keras.Input(shape=(lx, ly, 3))\n",
    "x         = layers.Conv2D(32, 3, strides=2, padding=\"same\", activation=\"relu\")(inputs)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x         = layers.Conv2D(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "\n",
    "shape_before_flattening = keras.backend.int_shape(x)[1:]\n",
    "\n",
    "x         = layers.Flatten()(x)\n",
    "x         = layers.Dense(512, activation=\"relu\")(x)\n",
    "\n",
    "z_mean    = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
    "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
    "z         = Sampling()([z_mean, z_log_var])\n",
    "\n",
    "encoder = keras.Model(inputs, [z_mean, z_log_var, z], name=\"encoder\")\n",
    "encoder.compile()\n",
    "# encoder.summary()"
   ]
  },
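  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Sampling` layer implements the reparameterization trick : it draws `z` from the Gaussian defined by `z_mean` and `z_log_var`, while keeping the graph differentiable with respect to both. A sketch consistent with the Keras VAE example cited above - our `modules.VAE.Sampling` may differ in detail :\n",
    "```python\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "class Sampling(layers.Layer):\n",
    "    # z = z_mean + sigma * epsilon, with epsilon ~ N(0, I)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        z_mean, z_log_var = inputs\n",
    "        batch   = tf.shape(z_mean)[0]\n",
    "        dim     = tf.shape(z_mean)[1]\n",
    "        epsilon = tf.random.normal(shape=(batch, dim))\n",
    "        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
    "```"
   ]
  },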
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs  = keras.Input(shape=(latent_dim,))\n",
    "\n",
    "x = layers.Dense(np.prod(shape_before_flattening))(inputs)\n",
    "x = layers.Reshape(shape_before_flattening)(x)\n",
    "\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(64, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "x       = layers.Conv2DTranspose(32, 3, strides=2, padding=\"same\", activation=\"relu\")(x)\n",
    "outputs = layers.Conv2DTranspose(3,  3, padding=\"same\", activation=\"sigmoid\")(x)\n",
    "\n",
    "decoder = keras.Model(inputs, outputs, name=\"decoder\")\n",
    "decoder.compile()\n",
    "# decoder.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VAE\n",
    "Our loss function is the weighted sum of two values.  \n",
    "`reconstruction_loss` which measures the loss during reconstruction.  \n",
    "`kl_loss` which measures the dispersion.  \n",
    "\n",
    "The weights are defined by: `r_loss_factor` :  \n",
    "`total_loss = r_loss_factor*reconstruction_loss + (1-r_loss_factor)*kl_loss`\n",
    "\n",
    "if `r_loss_factor = 1`, the loss function includes only `reconstruction_loss`  \n",
    "if `r_loss_factor = 0`, the loss function includes only `kl_loss`  \n",
    "In practice, a value arround 0.5 gives good results here.\n"
   ]
  },
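  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our `VAE` class (see `modules/VAE.py`) wires this weighted loss into a custom `train_step`. A sketch of the core computation, following the Keras example cited above - the actual module may differ in detail :\n",
    "```python\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "def train_step(self, x):\n",
    "    # x : a batch of images\n",
    "    with tf.GradientTape() as tape:\n",
    "        z_mean, z_log_var, z = self.encoder(x)\n",
    "        reconstruction       = self.decoder(z)\n",
    "        # Reconstruction loss : per-pixel binary cross-entropy, summed over each image\n",
    "        r_loss  = tf.reduce_mean(\n",
    "            tf.reduce_sum(keras.losses.binary_crossentropy(x, reconstruction), axis=(1, 2)))\n",
    "        # KL loss : divergence between the latent distribution and N(0, I)\n",
    "        kl_loss = -0.5 * tf.reduce_mean(\n",
    "            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1))\n",
    "        # The weighted sum described above\n",
    "        total_loss = self.r_loss_factor * r_loss + (1 - self.r_loss_factor) * kl_loss\n",
    "    grads = tape.gradient(total_loss, self.trainable_weights)\n",
    "    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
    "    return {'loss': total_loss, 'r_loss': r_loss, 'kl_loss': kl_loss}\n",
    "```"
   ]
  },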
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = VAE(encoder, decoder, r_loss_factor)\n",
    "\n",
    "vae.compile(optimizer=keras.optimizers.Adam())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Train\n",
    "With `scale=1`, need 20' for 10 epochs on a V100 (IDRIS)  \n",
    "...on a basic CPU, may be >40 hours !"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 - Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_draw,_   = data_gen[0]\n",
    "data_gen.rewind()\n",
    "\n",
    "# ---- Callback : Images encoded\n",
    "pwk.mkdir(run_dir + '/images-encoded')\n",
    "filename = run_dir + '/images-encoded/image-{epoch:03d}-{i:02d}.jpg'\n",
    "callback_images1 = ImagesCallback(filename, x=x_draw[:5], encoder=encoder,decoder=decoder)\n",
    "\n",
    "# ---- Callback : Images generated\n",
    "pwk.mkdir(run_dir + '/images-generated')\n",
    "filename = run_dir + '/images-generated/image-{epoch:03d}-{i:02d}.jpg'\n",
    "callback_images2 = ImagesCallback(filename, x=None, nb_images=5, z_dim=latent_dim, encoder=encoder,decoder=decoder)          \n",
    "\n",
    "# ---- Callback : Best model\n",
    "pwk.mkdir(run_dir + '/models')\n",
    "filename = run_dir + '/models/best_model'\n",
    "callback_bestmodel = BestModelCallback(filename)\n",
    "\n",
    "# ---- Callback tensorboard\n",
    "dirname = run_dir + '/logs'\n",
    "callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)\n",
    "\n",
    "callbacks_list = [callback_images1, callback_images2, callback_bestmodel, callback_tensorboard]\n",
    "callbacks_list = [callback_images1, callback_images2, callback_bestmodel]"
   ]
  },
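  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How does `ImagesCallback` work ? At the end of each epoch, it passes a few reference images through the model and saves the results, so we can watch the reconstruction improve over time. A minimal sketch of the encode/decode variant, assuming this behavior - the real callback in `modules/callbacks.py` also handles generation from random latent vectors :\n",
    "```python\n",
    "import numpy as np\n",
    "from skimage import io\n",
    "from tensorflow.keras.callbacks import Callback\n",
    "\n",
    "class MinimalImagesCallback(Callback):\n",
    "\n",
    "    def __init__(self, filename, x, encoder, decoder):\n",
    "        self.filename = filename          # template with {epoch} and {i}\n",
    "        self.x        = x                 # reference images\n",
    "        self.encoder  = encoder\n",
    "        self.decoder  = decoder\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        # Encode then decode the reference images, and save each result\n",
    "        z_mean, z_log_var, z = self.encoder.predict(self.x, verbose=0)\n",
    "        y = self.decoder.predict(z, verbose=0)\n",
    "        for i, img in enumerate(y):\n",
    "            io.imsave(self.filename.format(epoch=epoch, i=i),\n",
    "                      (img * 255).astype(np.uint8))\n",
    "```"
   ]
  },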
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 - Train it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.chrono_start()\n",
    "\n",
    "history = vae.fit(data_gen, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list)\n",
    "\n",
    "pwk.chrono_show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - About our training session\n",
    "### 6.1 - History"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.plot_history(history,  plot={\"Loss\":['loss','r_loss', 'kl_loss']}, save_as='01-history')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 - Reconstruction (input -> encoder -> decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs=[]\n",
    "labels=[]\n",
    "for epoch in range(1,epochs,1):\n",
    "    for i in range(5):\n",
    "        filename = f'{run_dir}/images-encoded/image-{epoch:03d}-{i:02d}.jpg'.format(epoch=epoch, i=i)\n",
    "        img      = io.imread(filename)\n",
    "        imgs.append(img)\n",
    "        \n",
    "\n",
    "pwk.subtitle('Original images :')\n",
    "pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')\n",
    "\n",
    "pwk.subtitle('Encoded/decoded images')\n",
    "pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')\n",
    "\n",
    "pwk.subtitle('Original images :')\n",
    "pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.3 Generation (latent -> decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs=[]\n",
    "labels=[]\n",
    "for epoch in range(1,epochs,1):\n",
    "    for i in range(5):\n",
    "        filename = f'{run_dir}/images-generated/image-{epoch:03d}-{i:02d}.jpg'.format(epoch=epoch, i=i)\n",
    "        img      = io.imread(filename)\n",
    "        imgs.append(img)\n",
    "        \n",
    "pwk.subtitle('Generated images from latent space')\n",
    "pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}