From bdfef613ae58f3556663f7426d224b291f7d5f0d Mon Sep 17 00:00:00 2001
From: "Jean-Luc Parouty Jean-Luc.Parouty@simap.grenoble-inp.fr"
 <paroutyj@f-dahu.u-ga.fr>
Date: Tue, 4 Feb 2020 19:02:23 +0100
Subject: [PATCH] Add save/load functionality to VAE model

---
 VAE/01-VAE with MNIST.ipynb      | 400 -------------------------------
 VAE/01-VAE-with-MNIST.ipynb      | 269 +++++++++++++++++++++
 VAE/02-VAE-with-MNIST-post.ipynb | 208 ++++++++++++++++
 VAE/modules/vae.py               | 107 ++++++---
 4 files changed, 554 insertions(+), 430 deletions(-)
 delete mode 100644 VAE/01-VAE with MNIST.ipynb
 create mode 100644 VAE/01-VAE-with-MNIST.ipynb
 create mode 100644 VAE/02-VAE-with-MNIST-post.ipynb

diff --git a/VAE/01-VAE with MNIST.ipynb b/VAE/01-VAE with MNIST.ipynb
deleted file mode 100644
index 8eb9c45..0000000
--- a/VAE/01-VAE with MNIST.ipynb	
+++ /dev/null
@@ -1,400 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Variational AutoEncoder\n",
-    "=======================\n",
-    "---\n",
-    "Formation Introduction au Deep Learning  (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020  \n",
-    "\n",
-    "## Variational AutoEncoder (VAE), with MNIST Dataset\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Step 1 - Init python stuff"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "IDLE 2020 - Practical Work Module\n",
-      "  Version            : 0.2.5\n",
-      "  Run time           : Tuesday 4 February 2020, 00:10:15\n",
-      "  Matplotlib style   : ../fidle/talk.mplstyle\n",
-      "  TensorFlow version : 2.0.0\n",
-      "  Keras version      : 2.2.4-tf\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "\n",
-    "import tensorflow as tf\n",
-    "import tensorflow.keras as keras\n",
-    "import tensorflow.keras.datasets.mnist as mnist\n",
-    "\n",
-    "import modules.vae\n",
-    "\n",
-    "import matplotlib.pyplot as plt\n",
-    "import matplotlib\n",
-    "import seaborn as sns\n",
-    "\n",
-    "import os,sys,h5py,json\n",
-    "\n",
-    "from importlib import reload\n",
-    "\n",
-    "sys.path.append('..')\n",
-    "import fidle.pwk as ooo\n",
-    "\n",
-    "ooo.init()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Step 2 - Get data"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "(60000, 28, 28, 1)\n",
-      "(10000, 28, 28, 1)\n"
-     ]
-    }
-   ],
-   "source": [
-    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
-    "\n",
-    "x_train = x_train.astype('float32') / 255.\n",
-    "x_train = np.expand_dims(x_train, axis=3)\n",
-    "x_test  = x_test.astype('float32') / 255.\n",
-    "x_test  = np.expand_dims(x_test, axis=3)\n",
-    "print(x_train.shape)\n",
-    "print(x_test.shape)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Step 3 - Get VAE model"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Model initialized.\n",
-      "Outputs will be in :  ./run/001\n"
-     ]
-    }
-   ],
-   "source": [
-    "# reload(modules.vae)\n",
-    "# reload(modules.callbacks)\n",
-    "\n",
-    "tag = '001'\n",
-    "\n",
-    "input_shape = (28,28,1)\n",
-    "z_dim       = 2\n",
-    "verbose     = 0\n",
-    "\n",
-    "encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
-    "           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
-    "           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
-    "           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n",
-    "         ]\n",
-    "\n",
-    "decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
-    "           {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
-    "           {'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
-    "           {'type':'Conv2DT', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n",
-    "         ]\n",
-    "\n",
-    "vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape, \n",
-    "                                         encoder_layers = encoder, \n",
-    "                                         decoder_layers = decoder,\n",
-    "                                         z_dim          = z_dim, \n",
-    "                                         verbose        = verbose,\n",
-    "                                         run_tag        = tag)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Step 4 - Compile it"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learning_rate = 0.0005\n",
-    "r_loss_factor = 1000\n",
-    "\n",
-    "vae.compile(learning_rate, r_loss_factor)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Step 5 - Train"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "batch_size        = 100\n",
-    "epochs            = 200\n",
-    "image_periodicity = 1      # for each epoch\n",
-    "chkpt_periodicity = 2      # for each epoch\n",
-    "initial_epoch     = 0\n",
-    "dataset_size      = 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Train on 60000 samples, validate on 10000 samples\n",
-      "Epoch 1/200\n",
-      "  100/60000 [..............................] - ETA: 23:40 - loss: 231.4378 - vae_r_loss: 231.4373 - vae_kl_loss: 5.3801e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.251492). Check your callbacks.\n",
-      "60000/60000 [==============================] - 6s 101us/sample - loss: 67.7431 - vae_r_loss: 65.0691 - vae_kl_loss: 2.6740 - val_loss: 55.6598 - val_vae_r_loss: 52.4039 - val_vae_kl_loss: 3.2560\n",
-      "Epoch 2/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 54.0334 - vae_r_loss: 50.4695 - vae_kl_loss: 3.5639 - val_loss: 52.9105 - val_vae_r_loss: 49.1433 - val_vae_kl_loss: 3.7672\n",
-      "Epoch 3/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 51.8937 - vae_r_loss: 47.9195 - vae_kl_loss: 3.9743 - val_loss: 51.1775 - val_vae_r_loss: 47.0874 - val_vae_kl_loss: 4.0901\n",
-      "Epoch 4/200\n",
-      "60000/60000 [==============================] - 4s 59us/sample - loss: 50.4622 - vae_r_loss: 46.1359 - vae_kl_loss: 4.3264 - val_loss: 49.8507 - val_vae_r_loss: 45.2015 - val_vae_kl_loss: 4.6492\n",
-      "Epoch 5/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 49.3577 - vae_r_loss: 44.8123 - vae_kl_loss: 4.5454 - val_loss: 48.9416 - val_vae_r_loss: 44.3832 - val_vae_kl_loss: 4.5584\n",
-      "Epoch 6/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 48.5603 - vae_r_loss: 43.8800 - vae_kl_loss: 4.6803 - val_loss: 48.1800 - val_vae_r_loss: 43.5046 - val_vae_kl_loss: 4.6754\n",
-      "Epoch 7/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 48.0286 - vae_r_loss: 43.2646 - vae_kl_loss: 4.7640 - val_loss: 47.9362 - val_vae_r_loss: 43.2833 - val_vae_kl_loss: 4.6529\n",
-      "Epoch 8/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 47.6163 - vae_r_loss: 42.7828 - vae_kl_loss: 4.8336 - val_loss: 47.6161 - val_vae_r_loss: 42.7176 - val_vae_kl_loss: 4.8985\n",
-      "Epoch 9/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 47.2654 - vae_r_loss: 42.3804 - vae_kl_loss: 4.8850 - val_loss: 47.1385 - val_vae_r_loss: 42.2280 - val_vae_kl_loss: 4.9105\n",
-      "Epoch 10/200\n",
-      "WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.732872). Check your callbacks.\n",
-      "  100/60000 [..............................] - ETA: 7:23 - loss: 47.8688 - vae_r_loss: 43.0966 - vae_kl_loss: 4.7722WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.366450). Check your callbacks.\n",
-      "60000/60000 [==============================] - 4s 70us/sample - loss: 46.9698 - vae_r_loss: 42.0353 - vae_kl_loss: 4.9345 - val_loss: 47.0246 - val_vae_r_loss: 42.1103 - val_vae_kl_loss: 4.9143\n",
-      "Epoch 11/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 46.7538 - vae_r_loss: 41.7733 - vae_kl_loss: 4.9805 - val_loss: 46.9033 - val_vae_r_loss: 41.9019 - val_vae_kl_loss: 5.0014\n",
-      "Epoch 12/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 46.4962 - vae_r_loss: 41.4867 - vae_kl_loss: 5.0095 - val_loss: 46.6990 - val_vae_r_loss: 41.8006 - val_vae_kl_loss: 4.8985\n",
-      "Epoch 13/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 46.3232 - vae_r_loss: 41.2603 - vae_kl_loss: 5.0629 - val_loss: 46.6737 - val_vae_r_loss: 41.4675 - val_vae_kl_loss: 5.2061\n",
-      "Epoch 14/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 46.1505 - vae_r_loss: 41.0678 - vae_kl_loss: 5.0828 - val_loss: 46.3871 - val_vae_r_loss: 41.4687 - val_vae_kl_loss: 4.9184\n",
-      "Epoch 15/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 45.9750 - vae_r_loss: 40.8533 - vae_kl_loss: 5.1217 - val_loss: 46.1730 - val_vae_r_loss: 41.0982 - val_vae_kl_loss: 5.0748\n",
-      "Epoch 16/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 45.8053 - vae_r_loss: 40.6467 - vae_kl_loss: 5.1586 - val_loss: 46.2439 - val_vae_r_loss: 41.1142 - val_vae_kl_loss: 5.1297\n",
-      "Epoch 17/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 45.6415 - vae_r_loss: 40.4657 - vae_kl_loss: 5.1758 - val_loss: 46.0754 - val_vae_r_loss: 41.0632 - val_vae_kl_loss: 5.0122\n",
-      "Epoch 18/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 45.5121 - vae_r_loss: 40.3147 - vae_kl_loss: 5.1974 - val_loss: 45.8663 - val_vae_r_loss: 40.5329 - val_vae_kl_loss: 5.3334\n",
-      "Epoch 19/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 45.3686 - vae_r_loss: 40.1475 - vae_kl_loss: 5.2211 - val_loss: 46.2054 - val_vae_r_loss: 41.1238 - val_vae_kl_loss: 5.0816\n",
-      "Epoch 20/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 45.2161 - vae_r_loss: 39.9703 - vae_kl_loss: 5.2458 - val_loss: 45.7448 - val_vae_r_loss: 40.6166 - val_vae_kl_loss: 5.1283\n",
-      "Epoch 21/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 45.1159 - vae_r_loss: 39.8419 - vae_kl_loss: 5.2740 - val_loss: 45.8612 - val_vae_r_loss: 40.8692 - val_vae_kl_loss: 4.9920\n",
-      "Epoch 22/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 44.9881 - vae_r_loss: 39.7023 - vae_kl_loss: 5.2857 - val_loss: 45.8085 - val_vae_r_loss: 40.2675 - val_vae_kl_loss: 5.5410\n",
-      "Epoch 23/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 44.8471 - vae_r_loss: 39.5384 - vae_kl_loss: 5.3087 - val_loss: 45.4330 - val_vae_r_loss: 40.0743 - val_vae_kl_loss: 5.3587\n",
-      "Epoch 24/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 44.7550 - vae_r_loss: 39.4362 - vae_kl_loss: 5.3188 - val_loss: 45.3320 - val_vae_r_loss: 39.9992 - val_vae_kl_loss: 5.3328\n",
-      "Epoch 25/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 44.6692 - vae_r_loss: 39.3461 - vae_kl_loss: 5.3232 - val_loss: 45.3552 - val_vae_r_loss: 40.0258 - val_vae_kl_loss: 5.3294\n",
-      "Epoch 26/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 44.5891 - vae_r_loss: 39.2333 - vae_kl_loss: 5.3558 - val_loss: 45.2681 - val_vae_r_loss: 39.9015 - val_vae_kl_loss: 5.3666\n",
-      "Epoch 27/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 44.5072 - vae_r_loss: 39.1374 - vae_kl_loss: 5.3698 - val_loss: 45.3209 - val_vae_r_loss: 39.9636 - val_vae_kl_loss: 5.3574\n",
-      "Epoch 28/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 44.4180 - vae_r_loss: 39.0149 - vae_kl_loss: 5.4031 - val_loss: 45.2435 - val_vae_r_loss: 39.7765 - val_vae_kl_loss: 5.4671\n",
-      "Epoch 29/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 44.3102 - vae_r_loss: 38.9046 - vae_kl_loss: 5.4057 - val_loss: 45.2258 - val_vae_r_loss: 39.8441 - val_vae_kl_loss: 5.3817\n",
-      "Epoch 30/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 44.2489 - vae_r_loss: 38.8299 - vae_kl_loss: 5.4190 - val_loss: 45.0044 - val_vae_r_loss: 39.6516 - val_vae_kl_loss: 5.3528\n",
-      "Epoch 31/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 44.1732 - vae_r_loss: 38.7482 - vae_kl_loss: 5.4249 - val_loss: 45.0000 - val_vae_r_loss: 39.5609 - val_vae_kl_loss: 5.4391\n",
-      "Epoch 32/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 44.0894 - vae_r_loss: 38.6580 - vae_kl_loss: 5.4314 - val_loss: 44.9769 - val_vae_r_loss: 39.5384 - val_vae_kl_loss: 5.4385\n",
-      "Epoch 33/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 44.0582 - vae_r_loss: 38.6092 - vae_kl_loss: 5.4490 - val_loss: 44.9346 - val_vae_r_loss: 39.3805 - val_vae_kl_loss: 5.5541\n",
-      "Epoch 34/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.9458 - vae_r_loss: 38.4818 - vae_kl_loss: 5.4640 - val_loss: 45.0624 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4813\n",
-      "Epoch 35/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.8850 - vae_r_loss: 38.4031 - vae_kl_loss: 5.4819 - val_loss: 45.0285 - val_vae_r_loss: 39.5350 - val_vae_kl_loss: 5.4935\n",
-      "Epoch 36/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 43.8698 - vae_r_loss: 38.3779 - vae_kl_loss: 5.4918 - val_loss: 44.9170 - val_vae_r_loss: 39.5714 - val_vae_kl_loss: 5.3456\n",
-      "Epoch 37/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.7739 - vae_r_loss: 38.2723 - vae_kl_loss: 5.5016 - val_loss: 44.8441 - val_vae_r_loss: 39.3665 - val_vae_kl_loss: 5.4776\n",
-      "Epoch 38/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.7084 - vae_r_loss: 38.1933 - vae_kl_loss: 5.5151 - val_loss: 44.9233 - val_vae_r_loss: 39.5526 - val_vae_kl_loss: 5.3706\n",
-      "Epoch 39/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.6626 - vae_r_loss: 38.1320 - vae_kl_loss: 5.5306 - val_loss: 44.6793 - val_vae_r_loss: 39.2304 - val_vae_kl_loss: 5.4489\n",
-      "Epoch 40/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 43.5838 - vae_r_loss: 38.0592 - vae_kl_loss: 5.5246 - val_loss: 44.6130 - val_vae_r_loss: 39.0715 - val_vae_kl_loss: 5.5415\n",
-      "Epoch 41/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.5194 - vae_r_loss: 37.9840 - vae_kl_loss: 5.5354 - val_loss: 44.8512 - val_vae_r_loss: 39.6158 - val_vae_kl_loss: 5.2354\n",
-      "Epoch 42/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.5129 - vae_r_loss: 37.9786 - vae_kl_loss: 5.5343 - val_loss: 44.6991 - val_vae_r_loss: 39.2098 - val_vae_kl_loss: 5.4894\n",
-      "Epoch 43/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.4707 - vae_r_loss: 37.9237 - vae_kl_loss: 5.5470 - val_loss: 44.7121 - val_vae_r_loss: 39.2446 - val_vae_kl_loss: 5.4675\n",
-      "Epoch 44/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.3832 - vae_r_loss: 37.8227 - vae_kl_loss: 5.5604 - val_loss: 44.9172 - val_vae_r_loss: 39.3446 - val_vae_kl_loss: 5.5726\n",
-      "Epoch 45/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.3868 - vae_r_loss: 37.8075 - vae_kl_loss: 5.5793 - val_loss: 44.5718 - val_vae_r_loss: 39.0284 - val_vae_kl_loss: 5.5434\n",
-      "Epoch 46/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.2774 - vae_r_loss: 37.6953 - vae_kl_loss: 5.5821 - val_loss: 44.6954 - val_vae_r_loss: 39.1276 - val_vae_kl_loss: 5.5678\n",
-      "Epoch 47/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.2765 - vae_r_loss: 37.6813 - vae_kl_loss: 5.5952 - val_loss: 44.6153 - val_vae_r_loss: 38.9606 - val_vae_kl_loss: 5.6547\n",
-      "Epoch 48/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 43.2385 - vae_r_loss: 37.6431 - vae_kl_loss: 5.5954 - val_loss: 44.5508 - val_vae_r_loss: 39.0830 - val_vae_kl_loss: 5.4678\n",
-      "Epoch 49/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.1847 - vae_r_loss: 37.5822 - vae_kl_loss: 5.6025 - val_loss: 44.8277 - val_vae_r_loss: 39.1688 - val_vae_kl_loss: 5.6589\n",
-      "Epoch 50/200\n",
-      "60000/60000 [==============================] - 4s 58us/sample - loss: 43.1557 - vae_r_loss: 37.5533 - vae_kl_loss: 5.6024 - val_loss: 44.5082 - val_vae_r_loss: 38.9529 - val_vae_kl_loss: 5.5553\n",
-      "Epoch 51/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.0726 - vae_r_loss: 37.4533 - vae_kl_loss: 5.6193 - val_loss: 44.6332 - val_vae_r_loss: 38.9104 - val_vae_kl_loss: 5.7228\n",
-      "Epoch 52/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 43.1003 - vae_r_loss: 37.4708 - vae_kl_loss: 5.6295 - val_loss: 44.5279 - val_vae_r_loss: 39.0846 - val_vae_kl_loss: 5.4433\n",
-      "Epoch 53/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 43.0121 - vae_r_loss: 37.3923 - vae_kl_loss: 5.6198 - val_loss: 44.5675 - val_vae_r_loss: 38.9651 - val_vae_kl_loss: 5.6024\n",
-      "Epoch 54/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9750 - vae_r_loss: 37.3273 - vae_kl_loss: 5.6477 - val_loss: 44.6084 - val_vae_r_loss: 39.0057 - val_vae_kl_loss: 5.6027\n",
-      "Epoch 55/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9669 - vae_r_loss: 37.3124 - vae_kl_loss: 5.6545 - val_loss: 44.4369 - val_vae_r_loss: 38.7499 - val_vae_kl_loss: 5.6870\n",
-      "Epoch 56/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9172 - vae_r_loss: 37.2666 - vae_kl_loss: 5.6506 - val_loss: 44.4817 - val_vae_r_loss: 38.8071 - val_vae_kl_loss: 5.6747\n",
-      "Epoch 57/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.8719 - vae_r_loss: 37.2088 - vae_kl_loss: 5.6630 - val_loss: 44.7545 - val_vae_r_loss: 39.1340 - val_vae_kl_loss: 5.6205\n",
-      "Epoch 58/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.8724 - vae_r_loss: 37.2070 - vae_kl_loss: 5.6654 - val_loss: 44.4428 - val_vae_r_loss: 38.8374 - val_vae_kl_loss: 5.6054\n",
-      "Epoch 59/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.8085 - vae_r_loss: 37.1356 - vae_kl_loss: 5.6729 - val_loss: 44.3657 - val_vae_r_loss: 38.8973 - val_vae_kl_loss: 5.4684\n",
-      "Epoch 60/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.7711 - vae_r_loss: 37.1025 - vae_kl_loss: 5.6687 - val_loss: 44.5526 - val_vae_r_loss: 38.7923 - val_vae_kl_loss: 5.7603\n",
-      "Epoch 61/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.7549 - vae_r_loss: 37.0712 - vae_kl_loss: 5.6837 - val_loss: 44.6274 - val_vae_r_loss: 39.1211 - val_vae_kl_loss: 5.5063\n",
-      "Epoch 62/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.7314 - vae_r_loss: 37.0368 - vae_kl_loss: 5.6946 - val_loss: 44.3828 - val_vae_r_loss: 38.8327 - val_vae_kl_loss: 5.5502\n",
-      "Epoch 63/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.6688 - vae_r_loss: 36.9835 - vae_kl_loss: 5.6853 - val_loss: 44.4869 - val_vae_r_loss: 38.8497 - val_vae_kl_loss: 5.6372\n",
-      "Epoch 64/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.6714 - vae_r_loss: 36.9633 - vae_kl_loss: 5.7080 - val_loss: 44.4562 - val_vae_r_loss: 38.7178 - val_vae_kl_loss: 5.7384\n",
-      "Epoch 65/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.6547 - vae_r_loss: 36.9360 - vae_kl_loss: 5.7187 - val_loss: 44.4947 - val_vae_r_loss: 38.8561 - val_vae_kl_loss: 5.6386\n",
-      "Epoch 66/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.5807 - vae_r_loss: 36.8625 - vae_kl_loss: 5.7182 - val_loss: 44.4270 - val_vae_r_loss: 38.7251 - val_vae_kl_loss: 5.7019\n",
-      "Epoch 67/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.5664 - vae_r_loss: 36.8466 - vae_kl_loss: 5.7197 - val_loss: 44.5878 - val_vae_r_loss: 38.8787 - val_vae_kl_loss: 5.7091\n",
-      "Epoch 68/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.5503 - vae_r_loss: 36.8269 - vae_kl_loss: 5.7235 - val_loss: 44.6236 - val_vae_r_loss: 38.8846 - val_vae_kl_loss: 5.7390\n",
-      "Epoch 69/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.5057 - vae_r_loss: 36.7706 - vae_kl_loss: 5.7352 - val_loss: 44.5720 - val_vae_r_loss: 38.9196 - val_vae_kl_loss: 5.6525\n",
-      "Epoch 70/200\n",
-      "60000/60000 [==============================] - 3s 57us/sample - loss: 42.4955 - vae_r_loss: 36.7553 - vae_kl_loss: 5.7402 - val_loss: 44.4059 - val_vae_r_loss: 38.8886 - val_vae_kl_loss: 5.5173\n",
-      "Epoch 71/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.4649 - vae_r_loss: 36.7251 - vae_kl_loss: 5.7398 - val_loss: 44.5864 - val_vae_r_loss: 38.8203 - val_vae_kl_loss: 5.7661\n",
-      "Epoch 72/200\n",
-      "60000/60000 [==============================] - 3s 58us/sample - loss: 42.4907 - vae_r_loss: 36.7440 - vae_kl_loss: 5.7467 - val_loss: 44.3493 - val_vae_r_loss: 38.6765 - val_vae_kl_loss: 5.6727\n",
-      "Epoch 73/200\n",
-      "60000/60000 [==============================] - 3s 56us/sample - loss: 42.4224 - vae_r_loss: 36.6558 - vae_kl_loss: 5.7666 - val_loss: 44.5477 - val_vae_r_loss: 38.7588 - val_vae_kl_loss: 5.7889\n",
-      "Epoch 74/200\n",
-      "43100/60000 [====================>.........] - ETA: 0s - loss: 42.3141 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7565"
-     ]
-    }
-   ],
-   "source": [
-    "vae.train(x_train,\n",
-    "          x_test,\n",
-    "          batch_size        = batch_size, \n",
-    "          epochs            = epochs,\n",
-    "          image_periodicity = image_periodicity,\n",
-    "          chkpt_periodicity = chkpt_periodicity,\n",
-    "          initial_epoch     = initial_epoch,\n",
-    "          dataset_size      = dataset_size,\n",
-    "          lr_decay          = 1\n",
-    "         )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/VAE/01-VAE-with-MNIST.ipynb b/VAE/01-VAE-with-MNIST.ipynb
new file mode 100644
index 0000000..ae68bca
--- /dev/null
+++ b/VAE/01-VAE-with-MNIST.ipynb
@@ -0,0 +1,269 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Variational AutoEncoder (VAE) with MNIST\n",
+    "========================================\n",
+    "---\n",
+    "Formation Introduction au Deep Learning  (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020  \n",
+    "\n",
+    "## Episode 1 - Train a model\n",
+    " - Defining a VAE model\n",
+    " - Build the model\n",
+    " - Train it\n",
+    " - Follow the learning process with Tensorboard\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 1 - Init python stuff"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "FIDLE 2020 - Variational AutoEncoder (VAE)\n",
+      "TensorFlow version : 2.0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "import tensorflow.keras.datasets.mnist as mnist\n",
+    "import sys, importlib\n",
+    "\n",
+    "import modules.vae\n",
+    "importlib.reload(modules.vae)\n",
+    "\n",
+    "print('FIDLE 2020 - Variational AutoEncoder (VAE)')\n",
+    "print('TensorFlow version :',tf.__version__)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 2 - Get data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Dataset loaded.\n",
+      "x_train shape : (60000, 28, 28, 1)\n",
+      "x_test_shape  : (10000, 28, 28, 1)\n"
+     ]
+    }
+   ],
+   "source": [
+    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
+    "\n",
+    "x_train = x_train.astype('float32') / 255.\n",
+    "x_train = np.expand_dims(x_train, axis=3)\n",
+    "x_test  = x_test.astype('float32') / 255.\n",
+    "x_test  = np.expand_dims(x_test, axis=3)\n",
+    "print('Dataset loaded.')\n",
+    "print(f'x_train shape : {x_train.shape}\\nx_test_shape  : {x_test.shape}')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 3 - Get VAE model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 67,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Model initialized.\n",
+      "Outputs will be in :  ./run/004\n"
+     ]
+    }
+   ],
+   "source": [
+    "tag = '004'\n",
+    "\n",
+    "input_shape = (28,28,1)\n",
+    "z_dim       = 2\n",
+    "verbose     = 0\n",
+    "\n",
+    "encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
+    "           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
+    "           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
+    "           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n",
+    "         ]\n",
+    "\n",
+    "decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n",
+    "           {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
+    "           {'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n",
+    "           {'type':'Conv2DT', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n",
+    "         ]\n",
+    "\n",
+    "vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape, \n",
+    "                                         encoder_layers = encoder, \n",
+    "                                         decoder_layers = decoder,\n",
+    "                                         z_dim          = z_dim, \n",
+    "                                         verbose        = verbose,\n",
+    "                                         run_tag        = tag)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Config saved in : ./run/004/models/vae_config.json\n",
+      "Model saved in  : ./run/004/models/model.h5\n"
+     ]
+    }
+   ],
+   "source": [
+    "vae.save()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dict_keys(['input_shape', 'encoder_layers', 'decoder_layers', 'z_dim', 'run_tag', 'verbose'])\n",
+      "Model initialized.\n",
+      "Outputs will be in :  ./run/004\n",
+      "Weights loaded from : ./run/004/models/model.h5\n"
+     ]
+    }
+   ],
+   "source": [
+    "vae=modules.vae.VariationalAutoencoder.load('004')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 4 - Compile it"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Compiled.\n",
+      "Optimizer is Adam with learning_rate=0.0005\n"
+     ]
+    }
+   ],
+   "source": [
+    "learning_rate = 0.0005\n",
+    "r_loss_factor = 1000\n",
+    "\n",
+    "vae.compile(learning_rate, r_loss_factor)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 5 - Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size        = 100\n",
+    "epochs            = 100\n",
+    "image_periodicity = 1      # for each epoch\n",
+    "chkpt_periodicity = 2      # for each epoch\n",
+    "initial_epoch     = 0\n",
+    "dataset_size      = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vae.train(x_train,\n",
+    "          x_test,\n",
+    "          batch_size        = batch_size, \n",
+    "          epochs            = epochs,\n",
+    "          image_periodicity = image_periodicity,\n",
+    "          chkpt_periodicity = chkpt_periodicity,\n",
+    "          initial_epoch     = initial_epoch,\n",
+    "          dataset_size      = dataset_size\n",
+    "         )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vae."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/VAE/02-VAE-with-MNIST-post.ipynb b/VAE/02-VAE-with-MNIST-post.ipynb
new file mode 100644
index 0000000..64e2a11
--- /dev/null
+++ b/VAE/02-VAE-with-MNIST-post.ipynb
@@ -0,0 +1,208 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Variational AutoEncoder (VAE) with MNIST\n",
+    "========================================\n",
+    "---\n",
+    "Formation Introduction au Deep Learning  (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020  \n",
+    "\n",
+    "## Episode 2 - Analyse our trained model\n",
+    " - Defining a VAE model\n",
+    " - Build the model\n",
+    " - Train it\n",
+    " - Follow the learning process with Tensorboard\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 1 - Init python stuff"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "FIDLE 2020 - Variational AutoEncoder (VAE)\n",
+      "TensorFlow version : 2.0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "import tensorflow.keras.datasets.mnist as mnist\n",
+    "import sys, importlib\n",
+    "\n",
+    "import modules.vae\n",
+    "importlib.reload(modules.vae)\n",
+    "\n",
+    "print('FIDLE 2020 - Variational AutoEncoder (VAE)')\n",
+    "print('TensorFlow version :',tf.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 2 - Get data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Dataset loaded.\n",
+      "x_train shape : (60000, 28, 28, 1)\n",
+      "x_test_shape  : (10000, 28, 28, 1)\n"
+     ]
+    }
+   ],
+   "source": [
+    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
+    "\n",
+    "x_train = x_train.astype('float32') / 255.\n",
+    "x_train = np.expand_dims(x_train, axis=3)\n",
+    "x_test  = x_test.astype('float32') / 255.\n",
+    "x_test  = np.expand_dims(x_test, axis=3)\n",
+    "print('Dataset loaded.')\n",
+    "print(f'x_train shape : {x_train.shape}\\nx_test_shape  : {x_test.shape}')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 3 - Load best model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vae\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "end_time  = time.time()\n",
+    "dt  = end_time-start_time\n",
+    "dth = str(datetime.timedelta(seconds=dt))\n",
+    "print(f'\\nTrain duration : {dt:.2f} sec. - {dth:}')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 4 - Compile it"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learning_rate = 0.0005\n",
+    "r_loss_factor = 1000\n",
+    "\n",
+    "vae.compile(learning_rate, r_loss_factor)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 5 - Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size        = 100\n",
+    "epochs            = 200\n",
+    "image_periodicity = 1      # for each epoch\n",
+    "chkpt_periodicity = 2      # for each epoch\n",
+    "initial_epoch     = 0\n",
+    "dataset_size      = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vae.train(x_train,\n",
+    "          x_test,\n",
+    "          batch_size        = batch_size, \n",
+    "          epochs            = epochs,\n",
+    "          image_periodicity = image_periodicity,\n",
+    "          chkpt_periodicity = chkpt_periodicity,\n",
+    "          initial_epoch     = initial_epoch,\n",
+    "          dataset_size      = dataset_size,\n",
+    "          lr_decay          = 1\n",
+    "         )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/VAE/modules/vae.py b/VAE/modules/vae.py
index e98409b..207cfc1 100644
--- a/VAE/modules/vae.py
+++ b/VAE/modules/vae.py
@@ -13,20 +13,21 @@ from tensorflow.keras.utils import plot_model
 import tensorflow.keras.datasets.imdb as imdb
 
 import modules.callbacks
-import os
+import os, json, time, datetime
 
 
 
 class VariationalAutoencoder():
 
     
-    def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='default', verbose=0):
-        
+    def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='000', verbose=0):
+               
         self.name           = 'Variational AutoEncoder'
-        self.input_shape    = input_shape
+        self.input_shape    = list(input_shape)
         self.encoder_layers = encoder_layers
         self.decoder_layers = decoder_layers
         self.z_dim          = z_dim
+        self.run_tag        = str(run_tag)
         self.verbose        = verbose
         self.run_directory  = f'./run/{run_tag}'
         
@@ -42,13 +43,14 @@ class VariationalAutoencoder():
         
         # ---- Add next layers
         i=1
-        for params in encoder_layers:
-            t=params['type']
-            params.pop('type')
-            if t=='Conv2D':
-                layer = Conv2D(**params, name=f"Layer_{i}")
-            if t=='Dropout':
-                layer = Dropout(**params)
+        for l_config in encoder_layers:
+            l_type   = l_config['type']
+            l_params = l_config.copy()
+            l_params.pop('type')
+            if l_type=='Conv2D':
+                layer = Conv2D(**l_params)
+            if l_type=='Dropout':
+                layer = Dropout(**l_params)
             x = layer(x)
             i+=1
             
@@ -83,13 +85,14 @@ class VariationalAutoencoder():
 
         # ---- Add next layers
         i=1
-        for params in decoder_layers:
-            t=params['type']
-            params.pop('type')
-            if t=='Conv2DT':
-                layer = Conv2DTranspose(**params, name=f"Layer_{i}")
-            if t=='Dropout':
-                layer = Dropout(**params)
+        for l_config in decoder_layers:
+            l_type   = l_config['type']
+            l_params = l_config.copy()
+            l_params.pop('type')
+            if l_type=='Conv2DT':
+                layer = Conv2DTranspose(**l_params)
+            if l_type=='Dropout':
+                layer = Dropout(**l_params)
             x = layer(x)
             i+=1
 
@@ -140,6 +143,8 @@ class VariationalAutoencoder():
                            loss = vae_loss,
                            metrics = [vae_r_loss, vae_kl_loss], 
                            experimental_run_tf_function=False)
+        print('Compiled.')
+        print(f'Optimizer is Adam with learning_rate={learning_rate:}')
     
     
     def train(self, 
@@ -165,7 +170,7 @@ class VariationalAutoencoder():
         callbacks_images = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)
         
         # ---- Callback : Learning rate scheduler
-        lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
+        #lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
         
         # ---- Callback : Checkpoint
         filename = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
@@ -179,17 +184,23 @@ class VariationalAutoencoder():
         dirname = self.run_directory+"/logs"
         callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)
 
-        callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard, lr_sched]
-
-        self.model.fit(x_train[:n_train], x_train[:n_train],
-                       batch_size = batch_size,
-                       shuffle = True,
-                       epochs = epochs,
-                       initial_epoch = initial_epoch,
-                       callbacks = callbacks_list,
-                       validation_data = (x_test[:n_test], x_test[:n_test])
-                        )
-        
+        callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard]
+
+        # ---- Let's go...
+        start_time   = time.time()
+        self.history = self.model.fit(x_train[:n_train], x_train[:n_train],
+                                      batch_size = batch_size,
+                                      shuffle = True,
+                                      epochs = epochs,
+                                      initial_epoch = initial_epoch,
+                                      callbacks = callbacks_list,
+                                      validation_data = (x_test[:n_test], x_test[:n_test])
+                                      )
+        end_time  = time.time()
+        dt  = end_time-start_time
+        dth = str(datetime.timedelta(seconds=int(dt)))
+        self.duration = dt
+        print(f'\nTrain duration : {dt:.2f} sec. - {dth:}')
         
     def plot_model(self):
         d=self.run_directory+'/figs'
@@ -198,3 +209,39 @@ class VariationalAutoencoder():
         plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes = True, show_layer_names = True)
 
         
+    def save(self,config='vae_config.json', model='model.h5'):
+        # ---- Save config in json
+        if config!=None:
+            to_save  = ['input_shape', 'encoder_layers', 'decoder_layers', 'z_dim', 'run_tag', 'verbose']
+            data     = { i:self.__dict__[i] for i in to_save }
+            filename = self.run_directory+'/models/'+config
+            with open(filename, 'w') as outfile:
+                json.dump(data, outfile)
+            print(f'Config saved in : {filename}')
+        # ---- Save model
+        if model!=None:
+            filename = self.run_directory+'/models/'+model
+            self.model.save(filename)
+            print(f'Model saved in  : {filename}')
+
+            
+    def load_weights(self,model='model.h5'):
+        filename = self.run_directory+'/models/'+model
+        self.model.load_weights(filename)
+        print(f'Weights loaded from : {filename}')
+    
+            
+    @classmethod
+    def load(cls, run_tag='000', config='vae_config.json', model='model.h5'):
+        # ---- Instantiate a new vae
+        filename = f'./run/{run_tag}/models/{config}'
+        with open(filename, 'r') as infile:
+            params=json.load(infile)
+            print(params.keys())
+#             vae=cls( params['input_shape'], params['encoder_layers'], params['decoder_layers'], params['z_dim'], '004', 0)
+            vae=cls( **params)
+        # ---- model==None, just return it
+        if model==None: return vae
+        # ---- model!=None, get weight
+        vae.load_weights(model)
+        return vae
\ No newline at end of file
-- 
GitLab