{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [VAE1] - Variational AutoEncoder (VAE) with MNIST\n",
    "<!-- DESC --> Episode 1 : Model construction and Training\n",
    "\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Understanding and implementing a **variational autoencoder** (VAE) neural network\n",
    " - Understanding a more **advanced programming model**\n",
    "\n",
    "Because the computational requirements are significant, it is preferable to start with a very simple dataset such as MNIST.\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Define a VAE model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process with TensorBoard"
   ]
  },
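  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A VAE encoder does not output a single point: it outputs a mean `mu` and a log-variance `log_var` (the `mu` and `log_var` layers of the encoder built in Step 3), and a latent vector `z` is then sampled from the corresponding Gaussian. The minimal NumPy sketch below shows the usual *reparameterization trick*, `z = mu + exp(log_var / 2) * eps` with `eps ~ N(0, I)`. It assumes, without this notebook showing it, that the `encoder_output` Lambda layer of `modules.vae` implements this standard formula; `sample_z` is a hypothetical helper, not part of `modules.vae`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_z(mu, log_var, rng):\n",
    "    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)\n",
    "    # and sigma = exp(log_var / 2); the randomness is isolated in eps,\n",
    "    # so gradients can flow through mu and log_var during training.\n",
    "    eps = rng.standard_normal(mu.shape)\n",
    "    return mu + np.exp(0.5 * log_var) * eps\n",
    "\n",
    "rng     = np.random.default_rng(0)\n",
    "mu      = np.zeros((5, 2))         # batch of 5 samples, z_dim = 2\n",
    "log_var = np.full((5, 2), -100.0)  # near-zero variance: z collapses onto mu\n",
    "z = sample_z(mu, log_var, rng)\n",
    "print(z.shape)                     # (5, 2)\n",
    "print(np.allclose(z, mu))          # True\n",
    "```\n",
    "\n",
    "With a larger `log_var`, repeated calls return different `z` values scattered around `mu`; the `vae_kl_loss` term reported during training is what keeps these distributions close to `N(0, I)`."
   ]
  },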
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "FIDLE 2020 - Variational AutoEncoder (VAE)\n",
      "TensorFlow version   : 2.0.0\n",
      "VAE version          : 1.28\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sys, importlib\n",
    "\n",
    "import modules.vae\n",
    "import modules.loader_MNIST\n",
    "\n",
    "from modules.vae          import VariationalAutoencoder\n",
    "from modules.loader_MNIST import Loader_MNIST\n",
    "\n",
    "VariationalAutoencoder.about()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset loaded.\n",
      "Normalized.\n",
      "Reshaped to (60000, 28, 28, 1)\n"
     ]
    }
   ],
   "source": [
    "(x_train, y_train), (x_test, y_test) = Loader_MNIST.load()"
   ]
  },
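  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Loader_MNIST.load()` reports that the images are normalized and reshaped to `(60000, 28, 28, 1)`. The minimal sketch below shows what such preprocessing typically looks like. It assumes that normalization means dividing 8-bit pixel values by 255 (the loader's actual code is not shown here), and it uses random pixels in place of the real dataset so that it runs without a download.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in for raw MNIST images: uint8 pixels in [0, 255], shape (n, 28, 28)\n",
    "rng   = np.random.default_rng(0)\n",
    "x_raw = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)\n",
    "\n",
    "# Normalize to [0, 1], the range of the decoder's sigmoid output\n",
    "x = x_raw.astype('float32') / 255.0\n",
    "# Add a channel axis, (n, 28, 28) -> (n, 28, 28, 1), as expected by Conv2D\n",
    "x = x.reshape(-1, 28, 28, 1)\n",
    "\n",
    "print(x.shape)                            # (100, 28, 28, 1)\n",
    "print(x.min() >= 0.0 and x.max() <= 1.0)  # True\n",
    "```"
   ]
  },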
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Get VAE model\n",
    "We are going to instantiate our VAE model.  \n",
    "It is defined as a Python class to keep our code lighter.  \n",
    "The descriptions of our two networks (encoder and decoder) are passed as parameters.  \n",
    "Our model will be saved in the directory: `./run/<tag>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model initialized.\n",
      "Outputs will be in  : ./run/MNIST.001\n",
      "\n",
      " ---------- Encoder -------------------------------------------------- \n",
      "\n",
      "Model: \"model_13\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "encoder_input (InputLayer)      [(None, 28, 28, 1)]  0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_12 (Conv2D)              (None, 28, 28, 32)   320         encoder_input[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_13 (Conv2D)              (None, 14, 14, 64)   18496       conv2d_12[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_14 (Conv2D)              (None, 7, 7, 64)     36928       conv2d_13[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_15 (Conv2D)              (None, 7, 7, 64)     36928       conv2d_14[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "flatten_3 (Flatten)             (None, 3136)         0           conv2d_15[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "mu (Dense)                      (None, 2)            6274        flatten_3[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "log_var (Dense)                 (None, 2)            6274        flatten_3[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "encoder_output (Lambda)         (None, 2)            0           mu[0][0]                         \n",
      "                                                                 log_var[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 105,220\n",
      "Trainable params: 105,220\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "\n",
      " ---------- Encoder -------------------------------------------------- \n",
      "\n",
      "Model: \"model_14\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "decoder_input (InputLayer)   [(None, 2)]               0         \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 3136)              9408      \n",
      "_________________________________________________________________\n",
      "reshape_3 (Reshape)          (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_12 (Conv2DT (None, 7, 7, 64)          36928     \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_13 (Conv2DT (None, 14, 14, 64)        36928     \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_14 (Conv2DT (None, 28, 28, 32)        18464     \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_15 (Conv2DT (None, 28, 28, 1)         289       \n",
      "=================================================================\n",
      "Total params: 102,017\n",
      "Trainable params: 102,017\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Config saved in     : ./run/MNIST.001/models/vae_config.json\n"
     ]
    }
   ],
   "source": [
    "tag = 'MNIST.001'\n",
    "\n",
    "input_shape = (28,28,1)\n",
    "z_dim       = 2\n",
    "verbose     = 1\n",
    "\n",
    "encoder= [ {'type':'Conv2D',          'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n",
    "         ]\n",
    "\n",
    "decoder= [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
    "           {'type':'Conv2DTranspose', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n",
    "         ]\n",
    "\n",
    "vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape, \n",
    "                                         encoder_layers = encoder, \n",
    "                                         decoder_layers = decoder,\n",
    "                                         z_dim          = z_dim, \n",
    "                                         verbose        = verbose,\n",
    "                                         run_tag        = tag)\n",
    "vae.save(model=None)"
   ]
  },
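  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes and parameter counts in the two summaries above follow directly from the layer descriptions. The sketch below recomputes them by hand, assuming 3x3 kernels with bias terms and `'same'` padding, for which a layer with `strides=2` maps a size `n` to `ceil(n / 2)` (hence 28 -> 14 -> 7 and a flattened size of `7 * 7 * 64 = 3136`). The helper names `conv_params` and `dense_params` are illustrative only.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def conv_params(k, c_in, filters):\n",
    "    # k*k*c_in kernel weights per filter, plus one bias per filter\n",
    "    # (the same formula holds for Conv2D and Conv2DTranspose)\n",
    "    return (k * k * c_in + 1) * filters\n",
    "\n",
    "def dense_params(n_in, n_out):\n",
    "    # one weight per input/output pair, plus one bias per output\n",
    "    return (n_in + 1) * n_out\n",
    "\n",
    "z_dim = 2\n",
    "\n",
    "# Encoder: (filters, strides) of each 3x3 Conv2D layer, 28x28x1 input\n",
    "size, c_in, enc_total = 28, 1, 0\n",
    "for filters, strides in [(32, 1), (64, 2), (64, 2), (64, 1)]:\n",
    "    enc_total += conv_params(3, c_in, filters)\n",
    "    size = math.ceil(size / strides)   # 'same' padding\n",
    "    c_in = filters\n",
    "flat = size * size * c_in              # 7 * 7 * 64\n",
    "enc_total += 2 * dense_params(flat, z_dim)   # mu and log_var layers\n",
    "\n",
    "# Decoder: Dense z_dim -> 3136, reshape to 7x7x64, then 3x3 Conv2DTranspose layers\n",
    "dec_total, c_in = dense_params(z_dim, flat), 64\n",
    "for filters in [64, 64, 32, 1]:\n",
    "    dec_total += conv_params(3, c_in, filters)\n",
    "    c_in = filters\n",
    "\n",
    "print(flat, enc_total, dec_total)      # 3136 105220 102017\n",
    "```\n",
    "\n",
    "The printed totals match the `Total params` lines of the encoder and decoder summaries (105,220 and 102,017)."
   ]
  },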
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Compile it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiled.\n"
     ]
    }
   ],
   "source": [
    "r_loss_factor = 1000\n",
    "\n",
    "vae.compile( optimizer='adam', r_loss_factor=r_loss_factor)"
   ]
  },
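  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training log in Step 5 shows that the total loss is the sum of a reconstruction term and a KL-divergence term (for example, at epoch 1: 60.5303 + 2.8176 = 63.3479). The sketch below is a NumPy version of the standard VAE loss under two assumptions not confirmed by this notebook: the reconstruction term is the per-image mean squared error multiplied by `r_loss_factor`, and the KL term is the closed-form divergence between `N(mu, exp(log_var))` and `N(0, I)`. `vae_loss` is a hypothetical helper, not the `modules.vae` implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def vae_loss(x, x_hat, mu, log_var, r_loss_factor=1000):\n",
    "    # Reconstruction term: per-image mean squared error, scaled by r_loss_factor\n",
    "    r_loss = r_loss_factor * np.mean(np.square(x - x_hat), axis=(1, 2, 3))\n",
    "    # KL(N(mu, exp(log_var)) || N(0, I)), summed over the latent dimensions\n",
    "    kl_loss = -0.5 * np.sum(1 + log_var - np.square(mu) - np.exp(log_var), axis=1)\n",
    "    # Batch averages of the total loss and of each term\n",
    "    return np.mean(r_loss + kl_loss), np.mean(r_loss), np.mean(kl_loss)\n",
    "\n",
    "rng     = np.random.default_rng(0)\n",
    "x       = rng.random((4, 28, 28, 1))\n",
    "mu      = np.zeros((4, 2))\n",
    "log_var = np.zeros((4, 2))\n",
    "# Perfect reconstruction and a posterior equal to the prior: every term is zero\n",
    "total, r_loss, kl_loss = vae_loss(x, x, mu, log_var)\n",
    "print(total, r_loss, kl_loss)\n",
    "```\n",
    "\n",
    "Under these assumptions, a large `r_loss_factor` favours accurate reconstructions over a well-organized latent space; lowering it gives the KL term relatively more weight."
   ]
  },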
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size        = 100\n",
    "epochs            = 100\n",
    "initial_epoch     = 0\n",
    "k_size            = 1      # 1 means using 100% of the dataset"
   ]
  },
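  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `batch_size = 100` and the full training set, each epoch runs 60000 / 100 = 600 batches. The small sketch below makes this arithmetic explicit for other values of `k_size`, assuming `k_size` is simply the fraction of the training set used at each epoch; `epoch_plan` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def epoch_plan(n_samples, batch_size, k_size):\n",
    "    # Hypothetical helper: samples and batches per epoch,\n",
    "    # assuming k_size is the fraction of the training set that is used\n",
    "    n_used = int(n_samples * k_size)\n",
    "    steps  = math.ceil(n_used / batch_size)   # last batch may be partial\n",
    "    return n_used, steps\n",
    "\n",
    "print(epoch_plan(60000, 100, 1))     # (60000, 600)\n",
    "print(epoch_plan(60000, 100, 0.5))   # (30000, 300)\n",
    "```"
   ]
  },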
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/100\n",
      "  100/60000 [..............................] - ETA: 1:16:33 - loss: 231.5715 - vae_r_loss: 231.5706 - vae_kl_loss: 8.8929e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.261125). Check your callbacks.\n",
      "60000/60000 [==============================] - 16s 259us/sample - loss: 63.3479 - vae_r_loss: 60.5303 - vae_kl_loss: 2.8176 - val_loss: 52.8295 - val_vae_r_loss: 49.3652 - val_vae_kl_loss: 3.4643\n",
      "Epoch 2/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 50.9248 - vae_r_loss: 46.7790 - vae_kl_loss: 4.1458 - val_loss: 49.8544 - val_vae_r_loss: 45.4392 - val_vae_kl_loss: 4.4152\n",
      "Epoch 3/100\n",
      "60000/60000 [==============================] - 8s 127us/sample - loss: 48.9337 - vae_r_loss: 44.4075 - vae_kl_loss: 4.5262 - val_loss: 48.1853 - val_vae_r_loss: 43.6399 - val_vae_kl_loss: 4.5454\n",
      "Epoch 4/100\n",
      "60000/60000 [==============================] - 8s 128us/sample - loss: 47.8006 - vae_r_loss: 43.0700 - vae_kl_loss: 4.7306 - val_loss: 47.6048 - val_vae_r_loss: 42.8379 - val_vae_kl_loss: 4.7669\n",
      "Epoch 5/100\n",
      "60000/60000 [==============================] - 8s 129us/sample - loss: 47.1728 - vae_r_loss: 42.3272 - vae_kl_loss: 4.8456 - val_loss: 47.1257 - val_vae_r_loss: 42.5182 - val_vae_kl_loss: 4.6075\n",
      "Epoch 6/100\n",
      "60000/60000 [==============================] - 8s 128us/sample - loss: 46.6197 - vae_r_loss: 41.6877 - vae_kl_loss: 4.9320 - val_loss: 46.6778 - val_vae_r_loss: 41.8177 - val_vae_kl_loss: 4.8601\n",
      "Epoch 7/100\n",
      "60000/60000 [==============================] - 8s 127us/sample - loss: 46.2559 - vae_r_loss: 41.2509 - vae_kl_loss: 5.0050 - val_loss: 46.8471 - val_vae_r_loss: 41.7164 - val_vae_kl_loss: 5.1308\n",
      "Epoch 8/100\n",
      "60000/60000 [==============================] - 8s 129us/sample - loss: 45.9705 - vae_r_loss: 40.9047 - vae_kl_loss: 5.0658 - val_loss: 46.1138 - val_vae_r_loss: 40.8994 - val_vae_kl_loss: 5.2144\n",
      "Epoch 9/100\n",
      "60000/60000 [==============================] - 8s 128us/sample - loss: 45.7034 - vae_r_loss: 40.5799 - vae_kl_loss: 5.1235 - val_loss: 45.9027 - val_vae_r_loss: 40.6272 - val_vae_kl_loss: 5.2755\n",
      "Epoch 10/100\n",
      "60000/60000 [==============================] - 8s 127us/sample - loss: 45.4206 - vae_r_loss: 40.2416 - vae_kl_loss: 5.1790 - val_loss: 45.8569 - val_vae_r_loss: 40.7173 - val_vae_kl_loss: 5.1396\n",
      "Epoch 11/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 45.2388 - vae_r_loss: 40.0393 - vae_kl_loss: 5.1995 - val_loss: 45.5438 - val_vae_r_loss: 40.2990 - val_vae_kl_loss: 5.2448\n",
      "Epoch 12/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 45.0701 - vae_r_loss: 39.8384 - vae_kl_loss: 5.2317 - val_loss: 45.1382 - val_vae_r_loss: 39.8545 - val_vae_kl_loss: 5.2838\n",
      "Epoch 13/100\n",
      "60000/60000 [==============================] - 6s 102us/sample - loss: 44.9229 - vae_r_loss: 39.6576 - vae_kl_loss: 5.2653 - val_loss: 45.2182 - val_vae_r_loss: 39.7051 - val_vae_kl_loss: 5.5130\n",
      "Epoch 14/100\n",
      "60000/60000 [==============================] - 6s 101us/sample - loss: 44.7520 - vae_r_loss: 39.4462 - vae_kl_loss: 5.3058 - val_loss: 44.9645 - val_vae_r_loss: 39.6967 - val_vae_kl_loss: 5.2678\n",
      "Epoch 15/100\n",
      "60000/60000 [==============================] - 7s 112us/sample - loss: 44.6182 - vae_r_loss: 39.2917 - vae_kl_loss: 5.3266 - val_loss: 45.1804 - val_vae_r_loss: 39.8132 - val_vae_kl_loss: 5.3673\n",
      "Epoch 16/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 44.5127 - vae_r_loss: 39.1662 - vae_kl_loss: 5.3465 - val_loss: 44.7445 - val_vae_r_loss: 39.5578 - val_vae_kl_loss: 5.1867\n",
      "Epoch 17/100\n",
      "60000/60000 [==============================] - 7s 122us/sample - loss: 44.3639 - vae_r_loss: 38.9776 - vae_kl_loss: 5.3864 - val_loss: 45.0144 - val_vae_r_loss: 39.4877 - val_vae_kl_loss: 5.5267\n",
      "Epoch 18/100\n",
      "60000/60000 [==============================] - 8s 131us/sample - loss: 44.2794 - vae_r_loss: 38.8709 - vae_kl_loss: 5.4085 - val_loss: 44.7394 - val_vae_r_loss: 39.3797 - val_vae_kl_loss: 5.3597\n",
      "Epoch 19/100\n",
      "60000/60000 [==============================] - 8s 127us/sample - loss: 44.1593 - vae_r_loss: 38.7339 - vae_kl_loss: 5.4254 - val_loss: 44.8999 - val_vae_r_loss: 39.5979 - val_vae_kl_loss: 5.3020\n",
      "Epoch 20/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 44.0882 - vae_r_loss: 38.6534 - vae_kl_loss: 5.4348 - val_loss: 44.5456 - val_vae_r_loss: 39.1507 - val_vae_kl_loss: 5.3949\n",
      "Epoch 21/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 43.9450 - vae_r_loss: 38.4855 - vae_kl_loss: 5.4596 - val_loss: 44.6002 - val_vae_r_loss: 39.1242 - val_vae_kl_loss: 5.4760\n",
      "Epoch 22/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.9054 - vae_r_loss: 38.4365 - vae_kl_loss: 5.4689 - val_loss: 44.7089 - val_vae_r_loss: 39.3456 - val_vae_kl_loss: 5.3633\n",
      "Epoch 23/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.8545 - vae_r_loss: 38.3684 - vae_kl_loss: 5.4860 - val_loss: 44.4804 - val_vae_r_loss: 39.0637 - val_vae_kl_loss: 5.4167\n",
      "Epoch 24/100\n",
      "60000/60000 [==============================] - 8s 129us/sample - loss: 43.7724 - vae_r_loss: 38.2644 - vae_kl_loss: 5.5080 - val_loss: 44.2626 - val_vae_r_loss: 38.7498 - val_vae_kl_loss: 5.5128\n",
      "Epoch 25/100\n",
      "60000/60000 [==============================] - 7s 120us/sample - loss: 43.6945 - vae_r_loss: 38.1870 - vae_kl_loss: 5.5075 - val_loss: 44.5109 - val_vae_r_loss: 38.9870 - val_vae_kl_loss: 5.5240\n",
      "Epoch 26/100\n",
      "60000/60000 [==============================] - 6s 103us/sample - loss: 43.6235 - vae_r_loss: 38.0840 - vae_kl_loss: 5.5396 - val_loss: 44.4880 - val_vae_r_loss: 38.9770 - val_vae_kl_loss: 5.5110\n",
      "Epoch 27/100\n",
      "60000/60000 [==============================] - 6s 104us/sample - loss: 43.5726 - vae_r_loss: 38.0224 - vae_kl_loss: 5.5502 - val_loss: 44.3887 - val_vae_r_loss: 39.0129 - val_vae_kl_loss: 5.3758\n",
      "Epoch 28/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 43.4963 - vae_r_loss: 37.9367 - vae_kl_loss: 5.5596 - val_loss: 44.2672 - val_vae_r_loss: 38.7244 - val_vae_kl_loss: 5.5427\n",
      "Epoch 29/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 43.4534 - vae_r_loss: 37.8870 - vae_kl_loss: 5.5663 - val_loss: 44.2616 - val_vae_r_loss: 38.6397 - val_vae_kl_loss: 5.6219\n",
      "Epoch 30/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 43.4108 - vae_r_loss: 37.8235 - vae_kl_loss: 5.5873 - val_loss: 44.0783 - val_vae_r_loss: 38.4805 - val_vae_kl_loss: 5.5978\n",
      "Epoch 31/100\n",
      "60000/60000 [==============================] - 8s 127us/sample - loss: 43.3281 - vae_r_loss: 37.7423 - vae_kl_loss: 5.5858 - val_loss: 44.2450 - val_vae_r_loss: 38.6322 - val_vae_kl_loss: 5.6128\n",
      "Epoch 32/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.3066 - vae_r_loss: 37.6978 - vae_kl_loss: 5.6089 - val_loss: 44.1004 - val_vae_r_loss: 38.3046 - val_vae_kl_loss: 5.7958\n",
      "Epoch 33/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 43.2654 - vae_r_loss: 37.6541 - vae_kl_loss: 5.6113 - val_loss: 44.0908 - val_vae_r_loss: 38.7236 - val_vae_kl_loss: 5.3672\n",
      "Epoch 34/100\n",
      "60000/60000 [==============================] - 8s 128us/sample - loss: 43.2006 - vae_r_loss: 37.5831 - vae_kl_loss: 5.6176 - val_loss: 44.3048 - val_vae_r_loss: 38.5594 - val_vae_kl_loss: 5.7453\n",
      "Epoch 35/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.1657 - vae_r_loss: 37.5287 - vae_kl_loss: 5.6370 - val_loss: 44.3178 - val_vae_r_loss: 38.7578 - val_vae_kl_loss: 5.5600\n",
      "Epoch 36/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.1199 - vae_r_loss: 37.4728 - vae_kl_loss: 5.6471 - val_loss: 43.9947 - val_vae_r_loss: 38.4591 - val_vae_kl_loss: 5.5356\n",
      "Epoch 37/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.0550 - vae_r_loss: 37.4052 - vae_kl_loss: 5.6499 - val_loss: 44.1075 - val_vae_r_loss: 38.4646 - val_vae_kl_loss: 5.6429\n",
      "Epoch 38/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 43.0274 - vae_r_loss: 37.3612 - vae_kl_loss: 5.6662 - val_loss: 44.1100 - val_vae_r_loss: 38.3107 - val_vae_kl_loss: 5.7994\n",
      "Epoch 39/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 43.0046 - vae_r_loss: 37.3353 - vae_kl_loss: 5.6693 - val_loss: 43.9765 - val_vae_r_loss: 38.1482 - val_vae_kl_loss: 5.8284\n",
      "Epoch 40/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.9831 - vae_r_loss: 37.2976 - vae_kl_loss: 5.6855 - val_loss: 44.1622 - val_vae_r_loss: 38.5135 - val_vae_kl_loss: 5.6488\n",
      "Epoch 41/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.9614 - vae_r_loss: 37.2653 - vae_kl_loss: 5.6961 - val_loss: 43.9546 - val_vae_r_loss: 38.3111 - val_vae_kl_loss: 5.6435\n",
      "Epoch 42/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.8949 - vae_r_loss: 37.1996 - vae_kl_loss: 5.6953 - val_loss: 44.0486 - val_vae_r_loss: 38.4427 - val_vae_kl_loss: 5.6059\n",
      "Epoch 43/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.8460 - vae_r_loss: 37.1553 - vae_kl_loss: 5.6907 - val_loss: 43.9027 - val_vae_r_loss: 38.3105 - val_vae_kl_loss: 5.5921\n",
      "Epoch 44/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.8550 - vae_r_loss: 37.1409 - vae_kl_loss: 5.7141 - val_loss: 44.0527 - val_vae_r_loss: 38.4803 - val_vae_kl_loss: 5.5724\n",
      "Epoch 45/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.7725 - vae_r_loss: 37.0586 - vae_kl_loss: 5.7139 - val_loss: 43.9695 - val_vae_r_loss: 38.1840 - val_vae_kl_loss: 5.7855\n",
      "Epoch 46/100\n",
      "60000/60000 [==============================] - 8s 127us/sample - loss: 42.7583 - vae_r_loss: 37.0431 - vae_kl_loss: 5.7152 - val_loss: 43.8917 - val_vae_r_loss: 38.4005 - val_vae_kl_loss: 5.4912\n",
      "Epoch 47/100\n",
      "60000/60000 [==============================] - 7s 112us/sample - loss: 42.7553 - vae_r_loss: 37.0322 - vae_kl_loss: 5.7231 - val_loss: 43.8994 - val_vae_r_loss: 38.2113 - val_vae_kl_loss: 5.6880\n",
      "Epoch 48/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.7218 - vae_r_loss: 36.9855 - vae_kl_loss: 5.7363 - val_loss: 43.6855 - val_vae_r_loss: 37.9163 - val_vae_kl_loss: 5.7693\n",
      "Epoch 49/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 42.6747 - vae_r_loss: 36.9308 - vae_kl_loss: 5.7439 - val_loss: 43.9899 - val_vae_r_loss: 38.4054 - val_vae_kl_loss: 5.5844\n",
      "Epoch 50/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 42.6405 - vae_r_loss: 36.8940 - vae_kl_loss: 5.7464 - val_loss: 43.9136 - val_vae_r_loss: 38.1742 - val_vae_kl_loss: 5.7394\n",
      "Epoch 51/100\n",
      "60000/60000 [==============================] - 7s 119us/sample - loss: 42.6486 - vae_r_loss: 36.8904 - vae_kl_loss: 5.7582 - val_loss: 43.7776 - val_vae_r_loss: 37.8941 - val_vae_kl_loss: 5.8834\n",
      "Epoch 52/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.5716 - vae_r_loss: 36.8282 - vae_kl_loss: 5.7433 - val_loss: 43.7207 - val_vae_r_loss: 37.9595 - val_vae_kl_loss: 5.7611\n",
      "Epoch 53/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.5695 - vae_r_loss: 36.8049 - vae_kl_loss: 5.7646 - val_loss: 43.8533 - val_vae_r_loss: 38.1541 - val_vae_kl_loss: 5.6993\n",
      "Epoch 54/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.5498 - vae_r_loss: 36.7861 - vae_kl_loss: 5.7637 - val_loss: 43.9121 - val_vae_r_loss: 38.2163 - val_vae_kl_loss: 5.6958\n",
      "Epoch 55/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 42.5410 - vae_r_loss: 36.7715 - vae_kl_loss: 5.7695 - val_loss: 43.9402 - val_vae_r_loss: 38.2676 - val_vae_kl_loss: 5.6726\n",
      "Epoch 56/100\n",
      "60000/60000 [==============================] - 7s 118us/sample - loss: 42.5186 - vae_r_loss: 36.7312 - vae_kl_loss: 5.7875 - val_loss: 43.8019 - val_vae_r_loss: 38.0754 - val_vae_kl_loss: 5.7266\n",
      "Epoch 57/100\n",
      "60000/60000 [==============================] - 7s 120us/sample - loss: 42.4861 - vae_r_loss: 36.6955 - vae_kl_loss: 5.7906 - val_loss: 43.7560 - val_vae_r_loss: 37.9236 - val_vae_kl_loss: 5.8325\n",
      "Epoch 58/100\n",
      "60000/60000 [==============================] - 7s 120us/sample - loss: 42.4515 - vae_r_loss: 36.6663 - vae_kl_loss: 5.7851 - val_loss: 43.8697 - val_vae_r_loss: 37.9264 - val_vae_kl_loss: 5.9433\n",
      "Epoch 59/100\n",
      "60000/60000 [==============================] - 8s 129us/sample - loss: 42.4103 - vae_r_loss: 36.6236 - vae_kl_loss: 5.7867 - val_loss: 43.8263 - val_vae_r_loss: 37.9798 - val_vae_kl_loss: 5.8465\n",
      "Epoch 60/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.3559 - vae_r_loss: 36.5556 - vae_kl_loss: 5.8003 - val_loss: 43.9342 - val_vae_r_loss: 38.3343 - val_vae_kl_loss: 5.5999\n",
      "Epoch 61/100\n",
      "60000/60000 [==============================] - 7s 117us/sample - loss: 42.4222 - vae_r_loss: 36.6162 - vae_kl_loss: 5.8060 - val_loss: 43.7412 - val_vae_r_loss: 37.9454 - val_vae_kl_loss: 5.7958\n",
      "Epoch 62/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.3689 - vae_r_loss: 36.5495 - vae_kl_loss: 5.8194 - val_loss: 43.6502 - val_vae_r_loss: 37.7721 - val_vae_kl_loss: 5.8781\n",
      "Epoch 63/100\n",
      "60000/60000 [==============================] - 7s 121us/sample - loss: 42.3349 - vae_r_loss: 36.5133 - vae_kl_loss: 5.8216 - val_loss: 43.8532 - val_vae_r_loss: 38.1812 - val_vae_kl_loss: 5.6720\n",
      "Epoch 64/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.3399 - vae_r_loss: 36.5160 - vae_kl_loss: 5.8239 - val_loss: 43.7407 - val_vae_r_loss: 37.7782 - val_vae_kl_loss: 5.9625\n",
      "Epoch 65/100\n",
      "60000/60000 [==============================] - 7s 122us/sample - loss: 42.3138 - vae_r_loss: 36.4957 - vae_kl_loss: 5.8181 - val_loss: 43.7347 - val_vae_r_loss: 37.8601 - val_vae_kl_loss: 5.8746\n",
      "Epoch 66/100\n",
      "60000/60000 [==============================] - 7s 123us/sample - loss: 42.2707 - vae_r_loss: 36.4429 - vae_kl_loss: 5.8278 - val_loss: 43.6608 - val_vae_r_loss: 37.7890 - val_vae_kl_loss: 5.8719\n",
      "Epoch 67/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.2985 - vae_r_loss: 36.4611 - vae_kl_loss: 5.8374 - val_loss: 43.6500 - val_vae_r_loss: 37.8897 - val_vae_kl_loss: 5.7603\n",
      "Epoch 68/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 42.2463 - vae_r_loss: 36.4053 - vae_kl_loss: 5.8411 - val_loss: 43.8904 - val_vae_r_loss: 38.1325 - val_vae_kl_loss: 5.7579\n",
      "Epoch 69/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.2397 - vae_r_loss: 36.4007 - vae_kl_loss: 5.8389 - val_loss: 43.7959 - val_vae_r_loss: 38.0308 - val_vae_kl_loss: 5.7651\n",
      "Epoch 70/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.2090 - vae_r_loss: 36.3648 - vae_kl_loss: 5.8442 - val_loss: 43.6900 - val_vae_r_loss: 37.9130 - val_vae_kl_loss: 5.7771\n",
      "Epoch 71/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.1998 - vae_r_loss: 36.3484 - vae_kl_loss: 5.8514 - val_loss: 43.6552 - val_vae_r_loss: 37.8492 - val_vae_kl_loss: 5.8060\n",
      "Epoch 72/100\n",
      "60000/60000 [==============================] - 7s 123us/sample - loss: 42.1943 - vae_r_loss: 36.3286 - vae_kl_loss: 5.8657 - val_loss: 43.6515 - val_vae_r_loss: 37.9546 - val_vae_kl_loss: 5.6969\n",
      "Epoch 73/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.1745 - vae_r_loss: 36.3193 - vae_kl_loss: 5.8552 - val_loss: 43.8444 - val_vae_r_loss: 38.1314 - val_vae_kl_loss: 5.7130\n",
      "Epoch 74/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.1288 - vae_r_loss: 36.2784 - vae_kl_loss: 5.8504 - val_loss: 43.8137 - val_vae_r_loss: 38.0373 - val_vae_kl_loss: 5.7764\n",
      "Epoch 75/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 42.1683 - vae_r_loss: 36.3033 - vae_kl_loss: 5.8650 - val_loss: 43.6371 - val_vae_r_loss: 37.9428 - val_vae_kl_loss: 5.6942\n",
      "Epoch 76/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.1302 - vae_r_loss: 36.2609 - vae_kl_loss: 5.8692 - val_loss: 43.8022 - val_vae_r_loss: 37.9602 - val_vae_kl_loss: 5.8420\n",
      "Epoch 77/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.1186 - vae_r_loss: 36.2503 - vae_kl_loss: 5.8684 - val_loss: 43.6853 - val_vae_r_loss: 37.9111 - val_vae_kl_loss: 5.7742\n",
      "Epoch 78/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.1304 - vae_r_loss: 36.2550 - vae_kl_loss: 5.8754 - val_loss: 43.7015 - val_vae_r_loss: 37.8260 - val_vae_kl_loss: 5.8755\n",
      "Epoch 79/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0687 - vae_r_loss: 36.1876 - vae_kl_loss: 5.8811 - val_loss: 43.6678 - val_vae_r_loss: 37.7893 - val_vae_kl_loss: 5.8785\n",
      "Epoch 80/100\n",
      "60000/60000 [==============================] - 8s 125us/sample - loss: 42.0476 - vae_r_loss: 36.1643 - vae_kl_loss: 5.8833 - val_loss: 43.6656 - val_vae_r_loss: 37.8170 - val_vae_kl_loss: 5.8487\n",
      "Epoch 81/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0584 - vae_r_loss: 36.1825 - vae_kl_loss: 5.8759 - val_loss: 43.6267 - val_vae_r_loss: 37.8788 - val_vae_kl_loss: 5.7480\n",
      "Epoch 82/100\n",
      "60000/60000 [==============================] - 7s 111us/sample - loss: 42.0196 - vae_r_loss: 36.1357 - vae_kl_loss: 5.8840 - val_loss: 43.7281 - val_vae_r_loss: 37.6417 - val_vae_kl_loss: 6.0864\n",
      "Epoch 83/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 42.0253 - vae_r_loss: 36.1311 - vae_kl_loss: 5.8943 - val_loss: 43.6205 - val_vae_r_loss: 37.8310 - val_vae_kl_loss: 5.7895\n",
      "Epoch 84/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0190 - vae_r_loss: 36.1276 - vae_kl_loss: 5.8914 - val_loss: 43.6444 - val_vae_r_loss: 37.8001 - val_vae_kl_loss: 5.8443\n",
      "Epoch 85/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0310 - vae_r_loss: 36.1556 - vae_kl_loss: 5.8754 - val_loss: 43.7125 - val_vae_r_loss: 37.8790 - val_vae_kl_loss: 5.8336\n",
      "Epoch 86/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0147 - vae_r_loss: 36.1276 - vae_kl_loss: 5.8872 - val_loss: 43.7536 - val_vae_r_loss: 37.7771 - val_vae_kl_loss: 5.9764\n",
      "Epoch 87/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 41.9674 - vae_r_loss: 36.0723 - vae_kl_loss: 5.8951 - val_loss: 43.6899 - val_vae_r_loss: 37.8354 - val_vae_kl_loss: 5.8545\n",
      "Epoch 88/100\n",
      "60000/60000 [==============================] - 7s 122us/sample - loss: 41.9717 - vae_r_loss: 36.0720 - vae_kl_loss: 5.8998 - val_loss: 43.6792 - val_vae_r_loss: 37.9402 - val_vae_kl_loss: 5.7390\n",
      "Epoch 89/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 41.9129 - vae_r_loss: 36.0126 - vae_kl_loss: 5.9003 - val_loss: 43.6925 - val_vae_r_loss: 37.8338 - val_vae_kl_loss: 5.8587\n",
      "Epoch 90/100\n",
      "60000/60000 [==============================] - 7s 124us/sample - loss: 41.9510 - vae_r_loss: 36.0328 - vae_kl_loss: 5.9181 - val_loss: 43.7327 - val_vae_r_loss: 37.8878 - val_vae_kl_loss: 5.8448\n",
      "Epoch 91/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 41.9122 - vae_r_loss: 35.9966 - vae_kl_loss: 5.9155 - val_loss: 43.7091 - val_vae_r_loss: 37.8563 - val_vae_kl_loss: 5.8527\n",
      "Epoch 92/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 41.9161 - vae_r_loss: 35.9930 - vae_kl_loss: 5.9231 - val_loss: 43.7270 - val_vae_r_loss: 37.8876 - val_vae_kl_loss: 5.8393\n",
      "Epoch 93/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 41.9154 - vae_r_loss: 35.9875 - vae_kl_loss: 5.9278 - val_loss: 43.6541 - val_vae_r_loss: 37.6903 - val_vae_kl_loss: 5.9639\n",
      "Epoch 94/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 41.8807 - vae_r_loss: 35.9663 - vae_kl_loss: 5.9144 - val_loss: 43.7643 - val_vae_r_loss: 37.8280 - val_vae_kl_loss: 5.9363\n",
      "Epoch 95/100\n",
      "60000/60000 [==============================] - 7s 125us/sample - loss: 41.9016 - vae_r_loss: 35.9739 - vae_kl_loss: 5.9277 - val_loss: 43.8913 - val_vae_r_loss: 37.9501 - val_vae_kl_loss: 5.9412\n",
      "Epoch 96/100\n",
      "60000/60000 [==============================] - 7s 120us/sample - loss: 41.8545 - vae_r_loss: 35.9314 - vae_kl_loss: 5.9231 - val_loss: 43.7067 - val_vae_r_loss: 37.7875 - val_vae_kl_loss: 5.9192\n",
      "Epoch 97/100\n",
      "60000/60000 [==============================] - 7s 118us/sample - loss: 41.8349 - vae_r_loss: 35.9267 - vae_kl_loss: 5.9083 - val_loss: 43.7083 - val_vae_r_loss: 37.7909 - val_vae_kl_loss: 5.9173\n",
      "Epoch 98/100\n",
      "60000/60000 [==============================] - 7s 123us/sample - loss: 41.8574 - vae_r_loss: 35.9213 - vae_kl_loss: 5.9361 - val_loss: 43.6804 - val_vae_r_loss: 37.8716 - val_vae_kl_loss: 5.8088\n",
      "Epoch 99/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 41.8132 - vae_r_loss: 35.8820 - vae_kl_loss: 5.9312 - val_loss: 43.5919 - val_vae_r_loss: 37.7066 - val_vae_kl_loss: 5.8853\n",
      "Epoch 100/100\n",
      "60000/60000 [==============================] - 8s 126us/sample - loss: 41.8338 - vae_r_loss: 35.9009 - vae_kl_loss: 5.9329 - val_loss: 43.6792 - val_vae_r_loss: 37.6395 - val_vae_kl_loss: 6.0397\n",
      "\n",
      "Train duration : 750.18 sec. - 0:12:30\n"
     ]
    }
   ],
   "source": [
    "vae.train(x_train,\n",
    "          x_test,\n",
    "          batch_size        = batch_size, \n",
    "          epochs            = epochs,\n",
    "          initial_epoch     = initial_epoch,\n",
    "          k_size            = k_size\n",
    "         )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}