{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# *** OBSOLETE *** [VAE1] - Variational AutoEncoder (VAE) with MNIST\n",
    "Episode 1 : Model construction and Training\n",
    "\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Understand and implement a **variational autoencoder** (VAE) neural network\n",
    " - Understand a more **advanced programming model**\n",
    "\n",
    "Because the computation is demanding, it is preferable to start with a very simple dataset such as MNIST.\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Define a VAE model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process with Tensorboard\n",
    "\n",
    "----\n",
    "## Bug Note :\n",
    "**Works with TensorFlow 2.0, but not with 2.2/2.3**  \n",
    "\n",
    "See :\n",
    " - https://github.com/tensorflow/tensorflow/issues/34944  \n",
    " - https://github.com/tensorflow/probability/issues/519  \n",
    "\n",
    "Workarounds :\n",
    " - Use TensorFlow 2.0\n",
    " - Add `tf.config.experimental_run_functions_eagerly(True)` before compilation.  \n",
    "   This works with versions 2.2 and 2.3, but performance is much worse (from 7 s to 1 min 50 s).\n",
    "\n",
    "----"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Initialize Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "FIDLE 2020 - Variational AutoEncoder (VAE)\n",
      "TensorFlow version   : 2.0.0\n",
      "VAE version          : 1.28\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sys, importlib\n",
    "\n",
    "import modules.vae\n",
    "import modules.loader_MNIST\n",
    "\n",
    "from modules.vae          import VariationalAutoencoder\n",
    "from modules.loader_MNIST import Loader_MNIST\n",
    "\n",
    "VariationalAutoencoder.about()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset loaded.\n",
      "Normalized.\n",
      "Reshaped to (60000, 28, 28, 1)\n"
     ]
    }
   ],
   "source": [
    "(x_train, y_train), (x_test, y_test) = Loader_MNIST.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Get VAE model\n",
    "We will now instantiate our VAE model.  \n",
    "It is defined as a Python class to keep our code lightweight.  \n",
    "The description of our two networks (encoder and decoder) is passed as parameters.  \n",
    "Our model will be saved in the folder: `./run/<tag>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
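    "tag = 'MNIST.001'          # run identifier: the model is saved under ./run/<tag>\n",
    "\n",
    "input_shape = (28,28,1)    # MNIST images: 28x28 pixels, 1 channel\n",
    "z_dim       = 2            # dimension of the latent space\n",
    "verbose     = 1\n",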
    "\n",
    "encoder= [ {'type':'Conv2D',          'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n",
    "         ]\n",
    "\n",
    "decoder= [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2DTranspose', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n",
    "         ]\n",
    "\n",
    "vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape, \n",
    "                                         encoder_layers = encoder, \n",
    "                                         decoder_layers = decoder,\n",
    "                                         z_dim          = z_dim, \n",
    "                                         verbose        = verbose,\n",
    "                                         run_tag        = tag)\n",
    "vae.save(model=None)"
   ]
  },
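  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check on the layer description above, the spatial size of the feature maps can be worked out by hand. With `padding='same'`, a `Conv2D` layer with stride $s$ maps an input of size $n$ to $\\lceil n/s \\rceil$, and a `Conv2DTranspose` layer maps $n$ to $n \\cdot s$. For the encoder this gives:\n",
    "\n",
    "$$\n",
    "28 \\xrightarrow{s=1} 28 \\xrightarrow{s=2} 14 \\xrightarrow{s=2} 7 \\xrightarrow{s=1} 7\n",
    "$$\n",
    "\n",
    "so the last encoder feature map holds $7 \\times 7 \\times 64 = 3136$ values. Assuming the decoder starts from a feature map of the same $7 \\times 7$ size (this is decided inside `modules.vae` and is not visible in this notebook), its four layers give:\n",
    "\n",
    "$$\n",
    "7 \\xrightarrow{s=1} 7 \\xrightarrow{s=2} 14 \\xrightarrow{s=2} 28 \\xrightarrow{s=1} 28\n",
    "$$\n",
    "\n",
    "With a single filter in the last layer, the output would then have shape $(28, 28, 1)$, matching `input_shape`."
   ]
  },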
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Compile it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r_loss_factor = 1000\n",
    "\n",
    "vae.compile( optimizer='adam', r_loss_factor=r_loss_factor)"
   ]
  },
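  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The role of `r_loss_factor` is not described in this notebook. A common convention, assumed here and not confirmed from `modules.vae`, is that it weights the reconstruction term against the Kullback-Leibler term of the VAE loss. The symbols below are introduced for illustration only: $x$ is the input, $\\hat{x}$ its reconstruction, $\\mu_j$ and $\\sigma_j^2$ are the mean and variance produced by the encoder for latent dimension $j$, $k$ stands for `r_loss_factor` and $d$ for `z_dim`.\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = k \\cdot \\mathcal{L}_{rec}(x, \\hat{x}) + \\mathcal{L}_{KL}\n",
    "\\qquad \\text{with} \\qquad\n",
    "\\mathcal{L}_{KL} = -\\frac{1}{2} \\sum_{j=1}^{d} \\left( 1 + \\log \\sigma_j^2 - \\mu_j^2 - \\sigma_j^2 \\right)\n",
    "$$\n",
    "\n",
    "Under this assumption, a large value such as 1000 favours faithful reconstructions, while a small value lets the KL term dominate and pulls the latent distribution closer to a standard normal."
   ]
  },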
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size        = 100\n",
    "epochs            = 100\n",
    "initial_epoch     = 0\n",
    "k_size            = 1      # 1 means using 100% of the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.train(x_train,\n",
    "          x_test,\n",
    "          batch_size        = batch_size, \n",
    "          epochs            = epochs,\n",
    "          initial_epoch     = initial_epoch,\n",
    "          k_size            = k_size\n",
    "         )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}