From bdfef613ae58f3556663f7426d224b291f7d5f0d Mon Sep 17 00:00:00 2001 From: "Jean-Luc Parouty Jean-Luc.Parouty@simap.grenoble-inp.fr" <paroutyj@f-dahu.u-ga.fr> Date: Tue, 4 Feb 2020 19:02:23 +0100 Subject: [PATCH] Add save/load functionality to VAE model --- VAE/01-VAE with MNIST.ipynb | 400 ------------------------------- VAE/01-VAE-with-MNIST.ipynb | 269 +++++++++++++++++++++ VAE/02-VAE-with-MNIST-post.ipynb | 208 ++++++++++++++++ VAE/modules/vae.py | 107 ++++++--- 4 files changed, 554 insertions(+), 430 deletions(-) delete mode 100644 VAE/01-VAE with MNIST.ipynb create mode 100644 VAE/01-VAE-with-MNIST.ipynb create mode 100644 VAE/02-VAE-with-MNIST-post.ipynb diff --git a/VAE/01-VAE with MNIST.ipynb b/VAE/01-VAE with MNIST.ipynb deleted file mode 100644 index 8eb9c45..0000000 --- a/VAE/01-VAE with MNIST.ipynb +++ /dev/null @@ -1,400 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Variational AutoEncoder\n", - "=======================\n", - "---\n", - "Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020 \n", - "\n", - "## Variational AutoEncoder (VAE), with MNIST Dataset\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1 - Init python stuff" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "IDLE 2020 - Practical Work Module\n", - " Version : 0.2.5\n", - " Run time : Tuesday 4 February 2020, 00:10:15\n", - " Matplotlib style : ../fidle/talk.mplstyle\n", - " TensorFlow version : 2.0.0\n", - " Keras version : 2.2.4-tf\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "\n", - "import tensorflow as tf\n", - "import tensorflow.keras as keras\n", - "import tensorflow.keras.datasets.mnist as mnist\n", - "\n", - "import modules.vae\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib\n", - "import seaborn as sns\n", - "\n", - "import os,sys,h5py,json\n", - "\n", - "from importlib import reload\n", - "\n", - "sys.path.append('..')\n", - "import fidle.pwk as ooo\n", - "\n", - "ooo.init()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2 - Get data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(60000, 28, 28, 1)\n", - "(10000, 28, 28, 1)\n" - ] - } - ], - "source": [ - "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", - "\n", - "x_train = x_train.astype('float32') / 255.\n", - "x_train = np.expand_dims(x_train, axis=3)\n", - "x_test = x_test.astype('float32') / 255.\n", - "x_test = np.expand_dims(x_test, axis=3)\n", - "print(x_train.shape)\n", - "print(x_test.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3 - Get VAE model" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model initialized.\n", - "Outputs will be in : ./run/001\n" - ] - } - ], - "source": [ - "# reload(modules.vae)\n", - "# reload(modules.callbacks)\n", - "\n", - "tag = '001'\n", - "\n", - "input_shape = (28,28,1)\n", - "z_dim = 2\n", - "verbose = 0\n", - "\n", - "encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n", - " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", - " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", - " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n", - " ]\n", - "\n", - "decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n", - " {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", - " {'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", - " {'type':'Conv2DT', 'filters':1, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n", - " ]\n", - "\n", - "vae = modules.vae.VariationalAutoencoder(input_shape = input_shape, \n", - " encoder_layers = encoder, \n", - " decoder_layers = decoder,\n", - " z_dim = z_dim, \n", - " verbose = verbose,\n", - " run_tag = tag)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4 - Compile it" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "learning_rate = 0.0005\n", - "r_loss_factor = 1000\n", - "\n", - "vae.compile(learning_rate, r_loss_factor)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 5 - Train" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 100\n", - "epochs = 200\n", - "image_periodicity = 1 # for each epoch\n", - "chkpt_periodicity = 2 # for each epoch\n", - "initial_epoch = 0\n", - "dataset_size = 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train on 60000 samples, validate on 10000 samples\n", - "Epoch 1/200\n", - " 100/60000 [..............................] - ETA: 23:40 - loss: 231.4378 - vae_r_loss: 231.4373 - vae_kl_loss: 5.3801e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.251492). Check your callbacks.\n", - "60000/60000 [==============================] - 6s 101us/sample - loss: 67.7431 - vae_r_loss: 65.0691 - vae_kl_loss: 2.6740 - val_loss: 55.6598 - val_vae_r_loss: 52.4039 - val_vae_kl_loss: 3.2560\n", - "Epoch 2/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 54.0334 - vae_r_loss: 50.4695 - vae_kl_loss: 3.5639 - val_loss: 52.9105 - val_vae_r_loss: 49.1433 - val_vae_kl_loss: 3.7672\n", - "Epoch 3/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 51.8937 - vae_r_loss: 47.9195 - vae_kl_loss: 3.9743 - val_loss: 51.1775 - val_vae_r_loss: 47.0874 - val_vae_kl_loss: 4.0901\n", - "Epoch 4/200\n", - "60000/60000 [==============================] - 4s 59us/sample - loss: 50.4622 - vae_r_loss: 46.1359 - vae_kl_loss: 4.3264 - val_loss: 49.8507 - val_vae_r_loss: 45.2015 - val_vae_kl_loss: 4.6492\n", - "Epoch 5/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 49.3577 - vae_r_loss: 44.8123 - vae_kl_loss: 4.5454 - val_loss: 48.9416 - val_vae_r_loss: 44.3832 - val_vae_kl_loss: 4.5584\n", - "Epoch 6/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 48.5603 - vae_r_loss: 43.8800 - vae_kl_loss: 4.6803 - val_loss: 48.1800 - val_vae_r_loss: 43.5046 - val_vae_kl_loss: 4.6754\n", - "Epoch 7/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 48.0286 - vae_r_loss: 43.2646 - vae_kl_loss: 4.7640 - val_loss: 47.9362 - val_vae_r_loss: 43.2833 - val_vae_kl_loss: 4.6529\n", - "Epoch 8/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 47.6163 - vae_r_loss: 42.7828 - vae_kl_loss: 4.8336 - val_loss: 47.6161 - val_vae_r_loss: 42.7176 - val_vae_kl_loss: 4.8985\n", - "Epoch 9/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 47.2654 - vae_r_loss: 42.3804 - vae_kl_loss: 4.8850 - val_loss: 47.1385 - val_vae_r_loss: 42.2280 - val_vae_kl_loss: 4.9105\n", - "Epoch 10/200\n", - "WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.732872). Check your callbacks.\n", - " 100/60000 [..............................] - ETA: 7:23 - loss: 47.8688 - vae_r_loss: 43.0966 - vae_kl_loss: 4.7722WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.366450). Check your callbacks.\n", - "60000/60000 [==============================] - 4s 70us/sample - loss: 46.9698 - vae_r_loss: 42.0353 - vae_kl_loss: 4.9345 - val_loss: 47.0246 - val_vae_r_loss: 42.1103 - val_vae_kl_loss: 4.9143\n", - "Epoch 11/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 46.7538 - vae_r_loss: 41.7733 - vae_kl_loss: 4.9805 - val_loss: 46.9033 - val_vae_r_loss: 41.9019 - val_vae_kl_loss: 5.0014\n", - "Epoch 12/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 46.4962 - vae_r_loss: 41.4867 - vae_kl_loss: 5.0095 - val_loss: 46.6990 - val_vae_r_loss: 41.8006 - val_vae_kl_loss: 4.8985\n", - "Epoch 13/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 46.3232 - vae_r_loss: 41.2603 - vae_kl_loss: 5.0629 - val_loss: 46.6737 - val_vae_r_loss: 41.4675 - val_vae_kl_loss: 5.2061\n", - "Epoch 14/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 46.1505 - vae_r_loss: 41.0678 - vae_kl_loss: 5.0828 - val_loss: 46.3871 - val_vae_r_loss: 41.4687 - val_vae_kl_loss: 4.9184\n", - "Epoch 15/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 45.9750 - vae_r_loss: 40.8533 - vae_kl_loss: 5.1217 - val_loss: 46.1730 - val_vae_r_loss: 41.0982 - val_vae_kl_loss: 5.0748\n", - "Epoch 16/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 45.8053 - vae_r_loss: 40.6467 - vae_kl_loss: 5.1586 - val_loss: 46.2439 - val_vae_r_loss: 41.1142 - val_vae_kl_loss: 5.1297\n", - "Epoch 17/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 45.6415 - vae_r_loss: 40.4657 - vae_kl_loss: 5.1758 - val_loss: 46.0754 - val_vae_r_loss: 41.0632 - val_vae_kl_loss: 5.0122\n", - "Epoch 18/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 45.5121 - vae_r_loss: 40.3147 - vae_kl_loss: 5.1974 - val_loss: 45.8663 - val_vae_r_loss: 40.5329 - val_vae_kl_loss: 5.3334\n", - "Epoch 19/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 45.3686 - vae_r_loss: 40.1475 - vae_kl_loss: 5.2211 - val_loss: 46.2054 - val_vae_r_loss: 41.1238 - val_vae_kl_loss: 5.0816\n", - "Epoch 20/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 45.2161 - vae_r_loss: 39.9703 - vae_kl_loss: 5.2458 - val_loss: 45.7448 - val_vae_r_loss: 40.6166 - val_vae_kl_loss: 5.1283\n", - "Epoch 21/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 45.1159 - vae_r_loss: 39.8419 - vae_kl_loss: 5.2740 - val_loss: 45.8612 - val_vae_r_loss: 40.8692 - val_vae_kl_loss: 4.9920\n", - "Epoch 22/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 44.9881 - vae_r_loss: 39.7023 - vae_kl_loss: 5.2857 - val_loss: 45.8085 - val_vae_r_loss: 40.2675 - val_vae_kl_loss: 5.5410\n", - "Epoch 23/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 44.8471 - vae_r_loss: 39.5384 - vae_kl_loss: 5.3087 - val_loss: 45.4330 - val_vae_r_loss: 40.0743 - val_vae_kl_loss: 5.3587\n", - "Epoch 24/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 44.7550 - vae_r_loss: 39.4362 - vae_kl_loss: 5.3188 - val_loss: 45.3320 - val_vae_r_loss: 39.9992 - val_vae_kl_loss: 5.3328\n", - "Epoch 25/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 44.6692 - vae_r_loss: 39.3461 - vae_kl_loss: 5.3232 - val_loss: 45.3552 - val_vae_r_loss: 40.0258 - val_vae_kl_loss: 5.3294\n", - "Epoch 26/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 44.5891 - vae_r_loss: 39.2333 - vae_kl_loss: 5.3558 - val_loss: 45.2681 - val_vae_r_loss: 39.9015 - val_vae_kl_loss: 5.3666\n", - "Epoch 27/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 44.5072 - vae_r_loss: 39.1374 - vae_kl_loss: 5.3698 - val_loss: 45.3209 - val_vae_r_loss: 39.9636 - val_vae_kl_loss: 5.3574\n", - "Epoch 28/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 44.4180 - vae_r_loss: 39.0149 - vae_kl_loss: 5.4031 - val_loss: 45.2435 - val_vae_r_loss: 39.7765 - val_vae_kl_loss: 5.4671\n", - "Epoch 29/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 44.3102 - vae_r_loss: 38.9046 - vae_kl_loss: 5.4057 - val_loss: 45.2258 - val_vae_r_loss: 39.8441 - val_vae_kl_loss: 5.3817\n", - "Epoch 30/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 44.2489 - vae_r_loss: 38.8299 - vae_kl_loss: 5.4190 - val_loss: 45.0044 - val_vae_r_loss: 39.6516 - val_vae_kl_loss: 5.3528\n", - "Epoch 31/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 44.1732 - vae_r_loss: 38.7482 - vae_kl_loss: 5.4249 - val_loss: 45.0000 - val_vae_r_loss: 39.5609 - val_vae_kl_loss: 5.4391\n", - "Epoch 32/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 44.0894 - vae_r_loss: 38.6580 - vae_kl_loss: 5.4314 - val_loss: 44.9769 - val_vae_r_loss: 39.5384 - val_vae_kl_loss: 5.4385\n", - "Epoch 33/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 44.0582 - vae_r_loss: 38.6092 - vae_kl_loss: 5.4490 - val_loss: 44.9346 - val_vae_r_loss: 39.3805 - val_vae_kl_loss: 5.5541\n", - "Epoch 34/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.9458 - vae_r_loss: 38.4818 - vae_kl_loss: 5.4640 - val_loss: 45.0624 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4813\n", - "Epoch 35/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.8850 - vae_r_loss: 38.4031 - vae_kl_loss: 5.4819 - val_loss: 45.0285 - val_vae_r_loss: 39.5350 - val_vae_kl_loss: 5.4935\n", - "Epoch 36/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 43.8698 - vae_r_loss: 38.3779 - vae_kl_loss: 5.4918 - val_loss: 44.9170 - val_vae_r_loss: 39.5714 - val_vae_kl_loss: 5.3456\n", - "Epoch 37/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.7739 - vae_r_loss: 38.2723 - vae_kl_loss: 5.5016 - val_loss: 44.8441 - val_vae_r_loss: 39.3665 - val_vae_kl_loss: 5.4776\n", - "Epoch 38/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.7084 - vae_r_loss: 38.1933 - vae_kl_loss: 5.5151 - val_loss: 44.9233 - val_vae_r_loss: 39.5526 - val_vae_kl_loss: 5.3706\n", - "Epoch 39/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.6626 - vae_r_loss: 38.1320 - vae_kl_loss: 5.5306 - val_loss: 44.6793 - val_vae_r_loss: 39.2304 - val_vae_kl_loss: 5.4489\n", - "Epoch 40/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 43.5838 - vae_r_loss: 38.0592 - vae_kl_loss: 5.5246 - val_loss: 44.6130 - val_vae_r_loss: 39.0715 - val_vae_kl_loss: 5.5415\n", - "Epoch 41/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.5194 - vae_r_loss: 37.9840 - vae_kl_loss: 5.5354 - val_loss: 44.8512 - val_vae_r_loss: 39.6158 - val_vae_kl_loss: 5.2354\n", - "Epoch 42/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.5129 - vae_r_loss: 37.9786 - vae_kl_loss: 5.5343 - val_loss: 44.6991 - val_vae_r_loss: 39.2098 - val_vae_kl_loss: 5.4894\n", - "Epoch 43/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.4707 - vae_r_loss: 37.9237 - vae_kl_loss: 5.5470 - val_loss: 44.7121 - val_vae_r_loss: 39.2446 - val_vae_kl_loss: 5.4675\n", - "Epoch 44/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.3832 - vae_r_loss: 37.8227 - vae_kl_loss: 5.5604 - val_loss: 44.9172 - val_vae_r_loss: 39.3446 - val_vae_kl_loss: 5.5726\n", - "Epoch 45/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.3868 - vae_r_loss: 37.8075 - vae_kl_loss: 5.5793 - val_loss: 44.5718 - val_vae_r_loss: 39.0284 - val_vae_kl_loss: 5.5434\n", - "Epoch 46/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.2774 - vae_r_loss: 37.6953 - vae_kl_loss: 5.5821 - val_loss: 44.6954 - val_vae_r_loss: 39.1276 - val_vae_kl_loss: 5.5678\n", - "Epoch 47/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.2765 - vae_r_loss: 37.6813 - vae_kl_loss: 5.5952 - val_loss: 44.6153 - val_vae_r_loss: 38.9606 - val_vae_kl_loss: 5.6547\n", - "Epoch 48/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 43.2385 - vae_r_loss: 37.6431 - vae_kl_loss: 5.5954 - val_loss: 44.5508 - val_vae_r_loss: 39.0830 - val_vae_kl_loss: 5.4678\n", - "Epoch 49/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.1847 - vae_r_loss: 37.5822 - vae_kl_loss: 5.6025 - val_loss: 44.8277 - val_vae_r_loss: 39.1688 - val_vae_kl_loss: 5.6589\n", - "Epoch 50/200\n", - "60000/60000 [==============================] - 4s 58us/sample - loss: 43.1557 - vae_r_loss: 37.5533 - vae_kl_loss: 5.6024 - val_loss: 44.5082 - val_vae_r_loss: 38.9529 - val_vae_kl_loss: 5.5553\n", - "Epoch 51/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.0726 - vae_r_loss: 37.4533 - vae_kl_loss: 5.6193 - val_loss: 44.6332 - val_vae_r_loss: 38.9104 - val_vae_kl_loss: 5.7228\n", - "Epoch 52/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 43.1003 - vae_r_loss: 37.4708 - vae_kl_loss: 5.6295 - val_loss: 44.5279 - val_vae_r_loss: 39.0846 - val_vae_kl_loss: 5.4433\n", - "Epoch 53/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 43.0121 - vae_r_loss: 37.3923 - vae_kl_loss: 5.6198 - val_loss: 44.5675 - val_vae_r_loss: 38.9651 - val_vae_kl_loss: 5.6024\n", - "Epoch 54/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9750 - vae_r_loss: 37.3273 - vae_kl_loss: 5.6477 - val_loss: 44.6084 - val_vae_r_loss: 39.0057 - val_vae_kl_loss: 5.6027\n", - "Epoch 55/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9669 - vae_r_loss: 37.3124 - vae_kl_loss: 5.6545 - val_loss: 44.4369 - val_vae_r_loss: 38.7499 - val_vae_kl_loss: 5.6870\n", - "Epoch 56/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9172 - vae_r_loss: 37.2666 - vae_kl_loss: 5.6506 - val_loss: 44.4817 - val_vae_r_loss: 38.8071 - val_vae_kl_loss: 5.6747\n", - "Epoch 57/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.8719 - vae_r_loss: 37.2088 - vae_kl_loss: 5.6630 - val_loss: 44.7545 - val_vae_r_loss: 39.1340 - val_vae_kl_loss: 5.6205\n", - "Epoch 58/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.8724 - vae_r_loss: 37.2070 - vae_kl_loss: 5.6654 - val_loss: 44.4428 - val_vae_r_loss: 38.8374 - val_vae_kl_loss: 5.6054\n", - "Epoch 59/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.8085 - vae_r_loss: 37.1356 - vae_kl_loss: 5.6729 - val_loss: 44.3657 - val_vae_r_loss: 38.8973 - val_vae_kl_loss: 5.4684\n", - "Epoch 60/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.7711 - vae_r_loss: 37.1025 - vae_kl_loss: 5.6687 - val_loss: 44.5526 - val_vae_r_loss: 38.7923 - val_vae_kl_loss: 5.7603\n", - "Epoch 61/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.7549 - vae_r_loss: 37.0712 - vae_kl_loss: 5.6837 - val_loss: 44.6274 - val_vae_r_loss: 39.1211 - val_vae_kl_loss: 5.5063\n", - "Epoch 62/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.7314 - vae_r_loss: 37.0368 - vae_kl_loss: 5.6946 - val_loss: 44.3828 - val_vae_r_loss: 38.8327 - val_vae_kl_loss: 5.5502\n", - "Epoch 63/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.6688 - vae_r_loss: 36.9835 - vae_kl_loss: 5.6853 - val_loss: 44.4869 - val_vae_r_loss: 38.8497 - val_vae_kl_loss: 5.6372\n", - "Epoch 64/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.6714 - vae_r_loss: 36.9633 - vae_kl_loss: 5.7080 - val_loss: 44.4562 - val_vae_r_loss: 38.7178 - val_vae_kl_loss: 5.7384\n", - "Epoch 65/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.6547 - vae_r_loss: 36.9360 - vae_kl_loss: 5.7187 - val_loss: 44.4947 - val_vae_r_loss: 38.8561 - val_vae_kl_loss: 5.6386\n", - "Epoch 66/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.5807 - vae_r_loss: 36.8625 - vae_kl_loss: 5.7182 - val_loss: 44.4270 - val_vae_r_loss: 38.7251 - val_vae_kl_loss: 5.7019\n", - "Epoch 67/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.5664 - vae_r_loss: 36.8466 - vae_kl_loss: 5.7197 - val_loss: 44.5878 - val_vae_r_loss: 38.8787 - val_vae_kl_loss: 5.7091\n", - "Epoch 68/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.5503 - vae_r_loss: 36.8269 - vae_kl_loss: 5.7235 - val_loss: 44.6236 - val_vae_r_loss: 38.8846 - val_vae_kl_loss: 5.7390\n", - "Epoch 69/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.5057 - vae_r_loss: 36.7706 - vae_kl_loss: 5.7352 - val_loss: 44.5720 - val_vae_r_loss: 38.9196 - val_vae_kl_loss: 5.6525\n", - "Epoch 70/200\n", - "60000/60000 [==============================] - 3s 57us/sample - loss: 42.4955 - vae_r_loss: 36.7553 - vae_kl_loss: 5.7402 - val_loss: 44.4059 - val_vae_r_loss: 38.8886 - val_vae_kl_loss: 5.5173\n", - "Epoch 71/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.4649 - vae_r_loss: 36.7251 - vae_kl_loss: 5.7398 - val_loss: 44.5864 - val_vae_r_loss: 38.8203 - val_vae_kl_loss: 5.7661\n", - "Epoch 72/200\n", - "60000/60000 [==============================] - 3s 58us/sample - loss: 42.4907 - vae_r_loss: 36.7440 - vae_kl_loss: 5.7467 - val_loss: 44.3493 - val_vae_r_loss: 38.6765 - val_vae_kl_loss: 5.6727\n", - "Epoch 73/200\n", - "60000/60000 [==============================] - 3s 56us/sample - loss: 42.4224 - vae_r_loss: 36.6558 - vae_kl_loss: 5.7666 - val_loss: 44.5477 - val_vae_r_loss: 38.7588 - val_vae_kl_loss: 5.7889\n", - "Epoch 74/200\n", - "43100/60000 [====================>.........] - ETA: 0s - loss: 42.3141 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7565" - ] - } - ], - "source": [ - "vae.train(x_train,\n", - " x_test,\n", - " batch_size = batch_size, \n", - " epochs = epochs,\n", - " image_periodicity = image_periodicity,\n", - " chkpt_periodicity = chkpt_periodicity,\n", - " initial_epoch = initial_epoch,\n", - " dataset_size = dataset_size,\n", - " lr_decay = 1\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/VAE/01-VAE-with-MNIST.ipynb b/VAE/01-VAE-with-MNIST.ipynb new file mode 100644 index 0000000..ae68bca --- /dev/null +++ b/VAE/01-VAE-with-MNIST.ipynb @@ -0,0 +1,269 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Variational AutoEncoder (VAE) with MNIST\n", + "========================================\n", + "---\n", + "Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020 \n", + "\n", + "## Episode 1 - Train a model\n", + " - Defining a VAE model\n", + " - Build the model\n", + " - Train it\n", + " - Follow the learning process with Tensorboard\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1 - Init python stuff" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FIDLE 2020 - Variational AutoEncoder (VAE)\n", + "TensorFlow version : 2.0.0\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import tensorflow as tf\n", + "import tensorflow.keras.datasets.mnist as mnist\n", + "import sys, importlib\n", + "\n", + "import modules.vae\n", + "importlib.reload(modules.vae)\n", + "\n", + "print('FIDLE 2020 - Variational AutoEncoder (VAE)')\n", + "print('TensorFlow version :',tf.__version__)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2 - Get data" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset loaded.\n", + "x_train shape : (60000, 28, 28, 1)\n", + "x_test_shape : (10000, 28, 28, 1)\n" + ] + } + ], + "source": [ + "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + "\n", + "x_train = x_train.astype('float32') / 255.\n", + "x_train = np.expand_dims(x_train, axis=3)\n", + "x_test = x_test.astype('float32') / 255.\n", + "x_test = np.expand_dims(x_test, axis=3)\n", + "print('Dataset loaded.')\n", + "print(f'x_train shape : {x_train.shape}\\nx_test_shape : {x_test.shape}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3 - Get VAE model" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model initialized.\n", + "Outputs will be in : ./run/004\n" + ] + } + ], + "source": [ + "tag = '004'\n", + "\n", + "input_shape = (28,28,1)\n", + "z_dim = 2\n", + "verbose = 0\n", + "\n", + "encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n", + " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", + " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", + " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n", + " ]\n", + "\n", + "decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n", + " {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", + " {'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", + " {'type':'Conv2DT', 'filters':1, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n", + " ]\n", + "\n", + "vae = modules.vae.VariationalAutoencoder(input_shape = input_shape, \n", + " encoder_layers = encoder, \n", + " decoder_layers = decoder,\n", + " z_dim = z_dim, \n", + " verbose = verbose,\n", + " run_tag = tag)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Config saved in : ./run/004/models/vae_config.json\n", + "Model saved in : ./run/004/models/model.h5\n" + ] + } + ], + "source": [ + "vae.save()" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['input_shape', 'encoder_layers', 'decoder_layers', 'z_dim', 'run_tag', 'verbose'])\n", + "Model initialized.\n", + "Outputs will be in : ./run/004\n", + "Weights loaded from : ./run/004/models/model.h5\n" + ] + } + ], + "source": [ + "vae=modules.vae.VariationalAutoencoder.load('004')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4 - Compile it" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Compiled.\n", + "Optimizer is Adam with learning_rate=0.0005\n" + ] + } + ], + "source": [ + "learning_rate = 0.0005\n", + "r_loss_factor = 1000\n", + "\n", + "vae.compile(learning_rate, r_loss_factor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5 - Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 100\n", + "epochs = 100\n", + "image_periodicity = 1 # for each epoch\n", + "chkpt_periodicity = 2 # for each epoch\n", + "initial_epoch = 0\n", + "dataset_size = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae.train(x_train,\n", + " x_test,\n", + " batch_size = batch_size, \n", + " epochs = epochs,\n", + " image_periodicity = image_periodicity,\n", + " chkpt_periodicity = chkpt_periodicity,\n", + " initial_epoch = initial_epoch,\n", + " dataset_size = dataset_size\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/VAE/02-VAE-with-MNIST-post.ipynb b/VAE/02-VAE-with-MNIST-post.ipynb new file mode 100644 index 0000000..64e2a11 --- /dev/null +++ b/VAE/02-VAE-with-MNIST-post.ipynb @@ -0,0 +1,208 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Variational AutoEncoder (VAE) with MNIST\n", + "========================================\n", + "---\n", + "Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020 \n", + "\n", + "## Episode 2 - Analyse our trained model\n", + " - Defining a VAE model\n", + " - Build the model\n", + " - Train it\n", + " - Follow the learning process with Tensorboard\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1 - Init python stuff" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FIDLE 2020 - Variational AutoEncoder (VAE)\n", + "TensorFlow version : 2.0.0\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import tensorflow as tf\n", + "import tensorflow.keras.datasets.mnist as mnist\n", + "import sys, importlib\n", + "\n", + "import modules.vae\n", + "importlib.reload(modules.vae)\n", + "\n", + "print('FIDLE 2020 - Variational AutoEncoder (VAE)')\n", + "print('TensorFlow version :',tf.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2 - Get data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset loaded.\n", + "x_train shape : (60000, 28, 28, 1)\n", + "x_test_shape : (10000, 28, 28, 1)\n" + ] + } + ], + "source": [ + "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + "\n", + "x_train = x_train.astype('float32') / 255.\n", + "x_train = np.expand_dims(x_train, axis=3)\n", + "x_test = x_test.astype('float32') / 255.\n", + "x_test = np.expand_dims(x_test, axis=3)\n", + "print('Dataset loaded.')\n", + "print(f'x_train shape : {x_train.shape}\\nx_test_shape : {x_test.shape}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3 - Load best model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "end_time = time.time()\n", + "dt = end_time-start_time\n", + "dth = str(datetime.timedelta(seconds=dt))\n", + "print(f'\\nTrain duration : {dt:.2f} sec. - {dth:}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4 - Compile it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "learning_rate = 0.0005\n", + "r_loss_factor = 1000\n", + "\n", + "vae.compile(learning_rate, r_loss_factor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5 - Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 100\n", + "epochs = 200\n", + "image_periodicity = 1 # for each epoch\n", + "chkpt_periodicity = 2 # for each epoch\n", + "initial_epoch = 0\n", + "dataset_size = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae.train(x_train,\n", + " x_test,\n", + " batch_size = batch_size, \n", + " epochs = epochs,\n", + " image_periodicity = image_periodicity,\n", + " chkpt_periodicity = chkpt_periodicity,\n", + " initial_epoch = initial_epoch,\n", + " dataset_size = dataset_size,\n", + " lr_decay = 1\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/VAE/modules/vae.py b/VAE/modules/vae.py index e98409b..207cfc1 100644 --- a/VAE/modules/vae.py +++ b/VAE/modules/vae.py @@ -13,20 +13,21 @@ from tensorflow.keras.utils import plot_model import tensorflow.keras.datasets.imdb as imdb import modules.callbacks -import os +import os, json, time, datetime class VariationalAutoencoder(): - def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='default', verbose=0): - + def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='000', verbose=0): + self.name = 'Variational AutoEncoder' - self.input_shape = input_shape + self.input_shape = list(input_shape) self.encoder_layers = encoder_layers self.decoder_layers = decoder_layers self.z_dim = z_dim + self.run_tag = str(run_tag) self.verbose = verbose self.run_directory = f'./run/{run_tag}' @@ -42,13 +43,14 @@ class VariationalAutoencoder(): # ---- Add next layers i=1 - for params in encoder_layers: - t=params['type'] - params.pop('type') - if t=='Conv2D': - layer = Conv2D(**params, name=f"Layer_{i}") - if t=='Dropout': - layer = Dropout(**params) + for l_config in encoder_layers: + l_type = l_config['type'] + l_params = l_config.copy() + l_params.pop('type') + if l_type=='Conv2D': + layer = Conv2D(**l_params) + if l_type=='Dropout': + layer = Dropout(**l_params) x = layer(x) i+=1 @@ -83,13 +85,14 @@ class VariationalAutoencoder(): # ---- Add next layers i=1 - for params in decoder_layers: - t=params['type'] - params.pop('type') - if t=='Conv2DT': - layer = Conv2DTranspose(**params, name=f"Layer_{i}") - if t=='Dropout': - layer = Dropout(**params) + for l_config in decoder_layers: + l_type = l_config['type'] + l_params = l_config.copy() + l_params.pop('type') + if l_type=='Conv2DT': + layer = Conv2DTranspose(**l_params) + if l_type=='Dropout': + layer = Dropout(**l_params) x = layer(x) i+=1 @@ -140,6 +143,8 @@ class VariationalAutoencoder(): loss = vae_loss, metrics = [vae_r_loss, vae_kl_loss], experimental_run_tf_function=False) + print('Compiled.') + print(f'Optimizer is Adam with learning_rate={learning_rate:}') def train(self, @@ -165,7 +170,7 @@ class VariationalAutoencoder(): callbacks_images = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self) # ---- Callback : Learning rate scheduler - lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) + #lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) # ---- Callback : Checkpoint filename = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5" @@ -179,17 +184,23 @@ class VariationalAutoencoder(): dirname = self.run_directory+"/logs" callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1) - callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard, lr_sched] - - self.model.fit(x_train[:n_train], x_train[:n_train], - batch_size = batch_size, - shuffle = True, - epochs = epochs, - initial_epoch = initial_epoch, - callbacks = callbacks_list, - validation_data = (x_test[:n_test], x_test[:n_test]) - ) - + callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard] + + # ---- Let's go... + start_time = time.time() + self.history = self.model.fit(x_train[:n_train], x_train[:n_train], + batch_size = batch_size, + shuffle = True, + epochs = epochs, + initial_epoch = initial_epoch, + callbacks = callbacks_list, + validation_data = (x_test[:n_test], x_test[:n_test]) + ) + end_time = time.time() + dt = end_time-start_time + dth = str(datetime.timedelta(seconds=int(dt))) + self.duration = dt + print(f'\nTrain duration : {dt:.2f} sec. - {dth:}') def plot_model(self): d=self.run_directory+'/figs' @@ -198,3 +209,39 @@ class VariationalAutoencoder(): plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes = True, show_layer_names = True) + def save(self,config='vae_config.json', model='model.h5'): + # ---- Save config in json + if config!=None: + to_save = ['input_shape', 'encoder_layers', 'decoder_layers', 'z_dim', 'run_tag', 'verbose'] + data = { i:self.__dict__[i] for i in to_save } + filename = self.run_directory+'/models/'+config + with open(filename, 'w') as outfile: + json.dump(data, outfile) + print(f'Config saved in : {filename}') + # ---- Save model + if model!=None: + filename = self.run_directory+'/models/'+model + self.model.save(filename) + print(f'Model saved in : {filename}') + + + def load_weights(self,model='model.h5'): + filename = self.run_directory+'/models/'+model + self.model.load_weights(filename) + print(f'Weights loaded from : {filename}') + + + @classmethod + def load(cls, run_tag='000', config='vae_config.json', model='model.h5'): + # ---- Instantiate a new vae + filename = f'./run/{run_tag}/models/{config}' + with open(filename, 'r') as infile: + params=json.load(infile) + print(params.keys()) +# vae=cls( params['input_shape'], params['encoder_layers'], params['decoder_layers'], params['z_dim'], '004', 0) + vae=cls( **params) + # ---- model==None, just return it + if model==None: return vae + # ---- model!=None, get weight + vae.load_weights(model) + return vae \ No newline at end of file -- GitLab