{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n", "\n", "# <!-- TITLE --> [VAE1] - Variational AutoEncoder (VAE) with MNIST\n", "<!-- DESC --> Episode 1 : Model construction and Training\n", "\n", "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", "## Objectives :\n", " - Understanding and implementing a **variational autoencoder** neural network (VAE)\n", " - Understanding a more **advanced programming model**\n", "\n", "Because training a VAE is computationally demanding, it is preferable to start with a very simple dataset such as MNIST.\n", "\n", "## What we're going to do :\n", "\n", " - Define a VAE model\n", " - Build the model\n", " - Train it\n", " - Follow the learning process with TensorBoard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Init Python stuff" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "FIDLE 2020 - Variational AutoEncoder (VAE)\n", "TensorFlow version : 2.0.0\n", "VAE version : 1.28\n" ] } ], "source": [ "import numpy as np\n", "import sys, importlib\n", "\n", "import modules.vae\n", "import modules.loader_MNIST\n", "\n", "from modules.vae import VariationalAutoencoder\n", "from modules.loader_MNIST import Loader_MNIST\n", "\n", "VariationalAutoencoder.about()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Get data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset loaded.\n", "Normalized.\n", "Reshaped to (60000, 28, 28, 1)\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test) = Loader_MNIST.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Get VAE model\n",
"We are going to instantiate our VAE model. \n", "It is defined as a Python class to keep our code lighter. \n", "The description of our two networks (encoder and decoder) is passed as parameters. \n", "Our model will be saved in the folder: `./run/<tag>`" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model initialized.\n", "Outputs will be in : ./run/MNIST.001\n", "\n", " ---------- Encoder -------------------------------------------------- \n", "\n", "Model: \"model_13\"\n", "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "encoder_input (InputLayer) [(None, 28, 28, 1)] 0 \n", "__________________________________________________________________________________________________\n", "conv2d_12 (Conv2D) (None, 28, 28, 32) 320 encoder_input[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_13 (Conv2D) (None, 14, 14, 64) 18496 conv2d_12[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_14 (Conv2D) (None, 7, 7, 64) 36928 conv2d_13[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_15 (Conv2D) (None, 7, 7, 64) 36928 conv2d_14[0][0] \n", "__________________________________________________________________________________________________\n", "flatten_3 (Flatten) (None, 3136) 0 conv2d_15[0][0] \n", "__________________________________________________________________________________________________\n", "mu (Dense) (None, 2) 6274 flatten_3[0][0] \n", "__________________________________________________________________________________________________\n", "log_var (Dense) (None, 2) 6274 flatten_3[0][0] \n", 
"__________________________________________________________________________________________________\n", "encoder_output (Lambda) (None, 2) 0 mu[0][0] \n", " log_var[0][0] \n", "==================================================================================================\n", "Total params: 105,220\n", "Trainable params: 105,220\n", "Non-trainable params: 0\n", "__________________________________________________________________________________________________\n", "\n", " ---------- Encoder -------------------------------------------------- \n", "\n", "Model: \"model_14\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "decoder_input (InputLayer) [(None, 2)] 0 \n", "_________________________________________________________________\n", "dense_3 (Dense) (None, 3136) 9408 \n", "_________________________________________________________________\n", "reshape_3 (Reshape) (None, 7, 7, 64) 0 \n", "_________________________________________________________________\n", "conv2d_transpose_12 (Conv2DT (None, 7, 7, 64) 36928 \n", "_________________________________________________________________\n", "conv2d_transpose_13 (Conv2DT (None, 14, 14, 64) 36928 \n", "_________________________________________________________________\n", "conv2d_transpose_14 (Conv2DT (None, 28, 28, 32) 18464 \n", "_________________________________________________________________\n", "conv2d_transpose_15 (Conv2DT (None, 28, 28, 1) 289 \n", "=================================================================\n", "Total params: 102,017\n", "Trainable params: 102,017\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "Config saved in : ./run/MNIST.001/models/vae_config.json\n" ] } ], "source": [ "tag = 'MNIST.001'\n", "\n", "input_shape = (28,28,1)\n", "z_dim = 2\n", "verbose = 1\n", "\n", "encoder= [ {'type':'Conv2D', 
'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n", " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", " {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}\n", " ]\n", "\n", "decoder= [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},\n", " {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", " {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},\n", " {'type':'Conv2DTranspose', 'filters':1, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}\n", " ]\n", "\n", "vae = modules.vae.VariationalAutoencoder(input_shape = input_shape, \n", " encoder_layers = encoder, \n", " decoder_layers = decoder,\n", " z_dim = z_dim, \n", " verbose = verbose,\n", " run_tag = tag)\n", "vae.save(model=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Compile it" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Compiled.\n" ] } ], "source": [ "r_loss_factor = 1000\n", "\n", "vae.compile( optimizer='adam', r_loss_factor=r_loss_factor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Train" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "batch_size = 100\n", "epochs = 100\n", "initial_epoch = 0\n", "k_size = 1 # 1 means using 100% of the dataset" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 60000 samples, validate on 10000 samples\n", 
"Epoch 1/100\n", " 100/60000 [..............................] - ETA: 1:16:33 - loss: 231.5715 - vae_r_loss: 231.5706 - vae_kl_loss: 8.8929e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.261125). Check your callbacks.\n", "60000/60000 [==============================] - 16s 259us/sample - loss: 63.3479 - vae_r_loss: 60.5303 - vae_kl_loss: 2.8176 - val_loss: 52.8295 - val_vae_r_loss: 49.3652 - val_vae_kl_loss: 3.4643\n", "Epoch 2/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 50.9248 - vae_r_loss: 46.7790 - vae_kl_loss: 4.1458 - val_loss: 49.8544 - val_vae_r_loss: 45.4392 - val_vae_kl_loss: 4.4152\n", "Epoch 3/100\n", "60000/60000 [==============================] - 8s 127us/sample - loss: 48.9337 - vae_r_loss: 44.4075 - vae_kl_loss: 4.5262 - val_loss: 48.1853 - val_vae_r_loss: 43.6399 - val_vae_kl_loss: 4.5454\n", "Epoch 4/100\n", "60000/60000 [==============================] - 8s 128us/sample - loss: 47.8006 - vae_r_loss: 43.0700 - vae_kl_loss: 4.7306 - val_loss: 47.6048 - val_vae_r_loss: 42.8379 - val_vae_kl_loss: 4.7669\n", "Epoch 5/100\n", "60000/60000 [==============================] - 8s 129us/sample - loss: 47.1728 - vae_r_loss: 42.3272 - vae_kl_loss: 4.8456 - val_loss: 47.1257 - val_vae_r_loss: 42.5182 - val_vae_kl_loss: 4.6075\n", "Epoch 6/100\n", "60000/60000 [==============================] - 8s 128us/sample - loss: 46.6197 - vae_r_loss: 41.6877 - vae_kl_loss: 4.9320 - val_loss: 46.6778 - val_vae_r_loss: 41.8177 - val_vae_kl_loss: 4.8601\n", "Epoch 7/100\n", "60000/60000 [==============================] - 8s 127us/sample - loss: 46.2559 - vae_r_loss: 41.2509 - vae_kl_loss: 5.0050 - val_loss: 46.8471 - val_vae_r_loss: 41.7164 - val_vae_kl_loss: 5.1308\n", "Epoch 8/100\n", "60000/60000 [==============================] - 8s 129us/sample - loss: 45.9705 - vae_r_loss: 40.9047 - vae_kl_loss: 5.0658 - val_loss: 46.1138 - val_vae_r_loss: 40.8994 - val_vae_kl_loss: 5.2144\n", "Epoch 
9/100\n", "60000/60000 [==============================] - 8s 128us/sample - loss: 45.7034 - vae_r_loss: 40.5799 - vae_kl_loss: 5.1235 - val_loss: 45.9027 - val_vae_r_loss: 40.6272 - val_vae_kl_loss: 5.2755\n", "Epoch 10/100\n", "60000/60000 [==============================] - 8s 127us/sample - loss: 45.4206 - vae_r_loss: 40.2416 - vae_kl_loss: 5.1790 - val_loss: 45.8569 - val_vae_r_loss: 40.7173 - val_vae_kl_loss: 5.1396\n", "Epoch 11/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 45.2388 - vae_r_loss: 40.0393 - vae_kl_loss: 5.1995 - val_loss: 45.5438 - val_vae_r_loss: 40.2990 - val_vae_kl_loss: 5.2448\n", "Epoch 12/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 45.0701 - vae_r_loss: 39.8384 - vae_kl_loss: 5.2317 - val_loss: 45.1382 - val_vae_r_loss: 39.8545 - val_vae_kl_loss: 5.2838\n", "Epoch 13/100\n", "60000/60000 [==============================] - 6s 102us/sample - loss: 44.9229 - vae_r_loss: 39.6576 - vae_kl_loss: 5.2653 - val_loss: 45.2182 - val_vae_r_loss: 39.7051 - val_vae_kl_loss: 5.5130\n", "Epoch 14/100\n", "60000/60000 [==============================] - 6s 101us/sample - loss: 44.7520 - vae_r_loss: 39.4462 - vae_kl_loss: 5.3058 - val_loss: 44.9645 - val_vae_r_loss: 39.6967 - val_vae_kl_loss: 5.2678\n", "Epoch 15/100\n", "60000/60000 [==============================] - 7s 112us/sample - loss: 44.6182 - vae_r_loss: 39.2917 - vae_kl_loss: 5.3266 - val_loss: 45.1804 - val_vae_r_loss: 39.8132 - val_vae_kl_loss: 5.3673\n", "Epoch 16/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 44.5127 - vae_r_loss: 39.1662 - vae_kl_loss: 5.3465 - val_loss: 44.7445 - val_vae_r_loss: 39.5578 - val_vae_kl_loss: 5.1867\n", "Epoch 17/100\n", "60000/60000 [==============================] - 7s 122us/sample - loss: 44.3639 - vae_r_loss: 38.9776 - vae_kl_loss: 5.3864 - val_loss: 45.0144 - val_vae_r_loss: 39.4877 - val_vae_kl_loss: 5.5267\n", "Epoch 18/100\n", "60000/60000 
[==============================] - 8s 131us/sample - loss: 44.2794 - vae_r_loss: 38.8709 - vae_kl_loss: 5.4085 - val_loss: 44.7394 - val_vae_r_loss: 39.3797 - val_vae_kl_loss: 5.3597\n", "Epoch 19/100\n", "60000/60000 [==============================] - 8s 127us/sample - loss: 44.1593 - vae_r_loss: 38.7339 - vae_kl_loss: 5.4254 - val_loss: 44.8999 - val_vae_r_loss: 39.5979 - val_vae_kl_loss: 5.3020\n", "Epoch 20/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 44.0882 - vae_r_loss: 38.6534 - vae_kl_loss: 5.4348 - val_loss: 44.5456 - val_vae_r_loss: 39.1507 - val_vae_kl_loss: 5.3949\n", "Epoch 21/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 43.9450 - vae_r_loss: 38.4855 - vae_kl_loss: 5.4596 - val_loss: 44.6002 - val_vae_r_loss: 39.1242 - val_vae_kl_loss: 5.4760\n", "Epoch 22/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.9054 - vae_r_loss: 38.4365 - vae_kl_loss: 5.4689 - val_loss: 44.7089 - val_vae_r_loss: 39.3456 - val_vae_kl_loss: 5.3633\n", "Epoch 23/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.8545 - vae_r_loss: 38.3684 - vae_kl_loss: 5.4860 - val_loss: 44.4804 - val_vae_r_loss: 39.0637 - val_vae_kl_loss: 5.4167\n", "Epoch 24/100\n", "60000/60000 [==============================] - 8s 129us/sample - loss: 43.7724 - vae_r_loss: 38.2644 - vae_kl_loss: 5.5080 - val_loss: 44.2626 - val_vae_r_loss: 38.7498 - val_vae_kl_loss: 5.5128\n", "Epoch 25/100\n", "60000/60000 [==============================] - 7s 120us/sample - loss: 43.6945 - vae_r_loss: 38.1870 - vae_kl_loss: 5.5075 - val_loss: 44.5109 - val_vae_r_loss: 38.9870 - val_vae_kl_loss: 5.5240\n", "Epoch 26/100\n", "60000/60000 [==============================] - 6s 103us/sample - loss: 43.6235 - vae_r_loss: 38.0840 - vae_kl_loss: 5.5396 - val_loss: 44.4880 - val_vae_r_loss: 38.9770 - val_vae_kl_loss: 5.5110\n", "Epoch 27/100\n", "60000/60000 [==============================] - 6s 
104us/sample - loss: 43.5726 - vae_r_loss: 38.0224 - vae_kl_loss: 5.5502 - val_loss: 44.3887 - val_vae_r_loss: 39.0129 - val_vae_kl_loss: 5.3758\n", "Epoch 28/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 43.4963 - vae_r_loss: 37.9367 - vae_kl_loss: 5.5596 - val_loss: 44.2672 - val_vae_r_loss: 38.7244 - val_vae_kl_loss: 5.5427\n", "Epoch 29/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 43.4534 - vae_r_loss: 37.8870 - vae_kl_loss: 5.5663 - val_loss: 44.2616 - val_vae_r_loss: 38.6397 - val_vae_kl_loss: 5.6219\n", "Epoch 30/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 43.4108 - vae_r_loss: 37.8235 - vae_kl_loss: 5.5873 - val_loss: 44.0783 - val_vae_r_loss: 38.4805 - val_vae_kl_loss: 5.5978\n", "Epoch 31/100\n", "60000/60000 [==============================] - 8s 127us/sample - loss: 43.3281 - vae_r_loss: 37.7423 - vae_kl_loss: 5.5858 - val_loss: 44.2450 - val_vae_r_loss: 38.6322 - val_vae_kl_loss: 5.6128\n", "Epoch 32/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.3066 - vae_r_loss: 37.6978 - vae_kl_loss: 5.6089 - val_loss: 44.1004 - val_vae_r_loss: 38.3046 - val_vae_kl_loss: 5.7958\n", "Epoch 33/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 43.2654 - vae_r_loss: 37.6541 - vae_kl_loss: 5.6113 - val_loss: 44.0908 - val_vae_r_loss: 38.7236 - val_vae_kl_loss: 5.3672\n", "Epoch 34/100\n", "60000/60000 [==============================] - 8s 128us/sample - loss: 43.2006 - vae_r_loss: 37.5831 - vae_kl_loss: 5.6176 - val_loss: 44.3048 - val_vae_r_loss: 38.5594 - val_vae_kl_loss: 5.7453\n", "Epoch 35/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.1657 - vae_r_loss: 37.5287 - vae_kl_loss: 5.6370 - val_loss: 44.3178 - val_vae_r_loss: 38.7578 - val_vae_kl_loss: 5.5600\n", "Epoch 36/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.1199 - 
vae_r_loss: 37.4728 - vae_kl_loss: 5.6471 - val_loss: 43.9947 - val_vae_r_loss: 38.4591 - val_vae_kl_loss: 5.5356\n", "Epoch 37/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.0550 - vae_r_loss: 37.4052 - vae_kl_loss: 5.6499 - val_loss: 44.1075 - val_vae_r_loss: 38.4646 - val_vae_kl_loss: 5.6429\n", "Epoch 38/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 43.0274 - vae_r_loss: 37.3612 - vae_kl_loss: 5.6662 - val_loss: 44.1100 - val_vae_r_loss: 38.3107 - val_vae_kl_loss: 5.7994\n", "Epoch 39/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 43.0046 - vae_r_loss: 37.3353 - vae_kl_loss: 5.6693 - val_loss: 43.9765 - val_vae_r_loss: 38.1482 - val_vae_kl_loss: 5.8284\n", "Epoch 40/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.9831 - vae_r_loss: 37.2976 - vae_kl_loss: 5.6855 - val_loss: 44.1622 - val_vae_r_loss: 38.5135 - val_vae_kl_loss: 5.6488\n", "Epoch 41/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.9614 - vae_r_loss: 37.2653 - vae_kl_loss: 5.6961 - val_loss: 43.9546 - val_vae_r_loss: 38.3111 - val_vae_kl_loss: 5.6435\n", "Epoch 42/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.8949 - vae_r_loss: 37.1996 - vae_kl_loss: 5.6953 - val_loss: 44.0486 - val_vae_r_loss: 38.4427 - val_vae_kl_loss: 5.6059\n", "Epoch 43/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.8460 - vae_r_loss: 37.1553 - vae_kl_loss: 5.6907 - val_loss: 43.9027 - val_vae_r_loss: 38.3105 - val_vae_kl_loss: 5.5921\n", "Epoch 44/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.8550 - vae_r_loss: 37.1409 - vae_kl_loss: 5.7141 - val_loss: 44.0527 - val_vae_r_loss: 38.4803 - val_vae_kl_loss: 5.5724\n", "Epoch 45/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.7725 - vae_r_loss: 37.0586 - vae_kl_loss: 
5.7139 - val_loss: 43.9695 - val_vae_r_loss: 38.1840 - val_vae_kl_loss: 5.7855\n", "Epoch 46/100\n", "60000/60000 [==============================] - 8s 127us/sample - loss: 42.7583 - vae_r_loss: 37.0431 - vae_kl_loss: 5.7152 - val_loss: 43.8917 - val_vae_r_loss: 38.4005 - val_vae_kl_loss: 5.4912\n", "Epoch 47/100\n", "60000/60000 [==============================] - 7s 112us/sample - loss: 42.7553 - vae_r_loss: 37.0322 - vae_kl_loss: 5.7231 - val_loss: 43.8994 - val_vae_r_loss: 38.2113 - val_vae_kl_loss: 5.6880\n", "Epoch 48/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.7218 - vae_r_loss: 36.9855 - vae_kl_loss: 5.7363 - val_loss: 43.6855 - val_vae_r_loss: 37.9163 - val_vae_kl_loss: 5.7693\n", "Epoch 49/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 42.6747 - vae_r_loss: 36.9308 - vae_kl_loss: 5.7439 - val_loss: 43.9899 - val_vae_r_loss: 38.4054 - val_vae_kl_loss: 5.5844\n", "Epoch 50/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 42.6405 - vae_r_loss: 36.8940 - vae_kl_loss: 5.7464 - val_loss: 43.9136 - val_vae_r_loss: 38.1742 - val_vae_kl_loss: 5.7394\n", "Epoch 51/100\n", "60000/60000 [==============================] - 7s 119us/sample - loss: 42.6486 - vae_r_loss: 36.8904 - vae_kl_loss: 5.7582 - val_loss: 43.7776 - val_vae_r_loss: 37.8941 - val_vae_kl_loss: 5.8834\n", "Epoch 52/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.5716 - vae_r_loss: 36.8282 - vae_kl_loss: 5.7433 - val_loss: 43.7207 - val_vae_r_loss: 37.9595 - val_vae_kl_loss: 5.7611\n", "Epoch 53/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.5695 - vae_r_loss: 36.8049 - vae_kl_loss: 5.7646 - val_loss: 43.8533 - val_vae_r_loss: 38.1541 - val_vae_kl_loss: 5.6993\n", "Epoch 54/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.5498 - vae_r_loss: 36.7861 - vae_kl_loss: 5.7637 - val_loss: 43.9121 - 
val_vae_r_loss: 38.2163 - val_vae_kl_loss: 5.6958\n", "Epoch 55/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 42.5410 - vae_r_loss: 36.7715 - vae_kl_loss: 5.7695 - val_loss: 43.9402 - val_vae_r_loss: 38.2676 - val_vae_kl_loss: 5.6726\n", "Epoch 56/100\n", "60000/60000 [==============================] - 7s 118us/sample - loss: 42.5186 - vae_r_loss: 36.7312 - vae_kl_loss: 5.7875 - val_loss: 43.8019 - val_vae_r_loss: 38.0754 - val_vae_kl_loss: 5.7266\n", "Epoch 57/100\n", "60000/60000 [==============================] - 7s 120us/sample - loss: 42.4861 - vae_r_loss: 36.6955 - vae_kl_loss: 5.7906 - val_loss: 43.7560 - val_vae_r_loss: 37.9236 - val_vae_kl_loss: 5.8325\n", "Epoch 58/100\n", "60000/60000 [==============================] - 7s 120us/sample - loss: 42.4515 - vae_r_loss: 36.6663 - vae_kl_loss: 5.7851 - val_loss: 43.8697 - val_vae_r_loss: 37.9264 - val_vae_kl_loss: 5.9433\n", "Epoch 59/100\n", "60000/60000 [==============================] - 8s 129us/sample - loss: 42.4103 - vae_r_loss: 36.6236 - vae_kl_loss: 5.7867 - val_loss: 43.8263 - val_vae_r_loss: 37.9798 - val_vae_kl_loss: 5.8465\n", "Epoch 60/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.3559 - vae_r_loss: 36.5556 - vae_kl_loss: 5.8003 - val_loss: 43.9342 - val_vae_r_loss: 38.3343 - val_vae_kl_loss: 5.5999\n", "Epoch 61/100\n", "60000/60000 [==============================] - 7s 117us/sample - loss: 42.4222 - vae_r_loss: 36.6162 - vae_kl_loss: 5.8060 - val_loss: 43.7412 - val_vae_r_loss: 37.9454 - val_vae_kl_loss: 5.7958\n", "Epoch 62/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.3689 - vae_r_loss: 36.5495 - vae_kl_loss: 5.8194 - val_loss: 43.6502 - val_vae_r_loss: 37.7721 - val_vae_kl_loss: 5.8781\n", "Epoch 63/100\n", "60000/60000 [==============================] - 7s 121us/sample - loss: 42.3349 - vae_r_loss: 36.5133 - vae_kl_loss: 5.8216 - val_loss: 43.8532 - val_vae_r_loss: 38.1812 - 
val_vae_kl_loss: 5.6720\n", "Epoch 64/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.3399 - vae_r_loss: 36.5160 - vae_kl_loss: 5.8239 - val_loss: 43.7407 - val_vae_r_loss: 37.7782 - val_vae_kl_loss: 5.9625\n", "Epoch 65/100\n", "60000/60000 [==============================] - 7s 122us/sample - loss: 42.3138 - vae_r_loss: 36.4957 - vae_kl_loss: 5.8181 - val_loss: 43.7347 - val_vae_r_loss: 37.8601 - val_vae_kl_loss: 5.8746\n", "Epoch 66/100\n", "60000/60000 [==============================] - 7s 123us/sample - loss: 42.2707 - vae_r_loss: 36.4429 - vae_kl_loss: 5.8278 - val_loss: 43.6608 - val_vae_r_loss: 37.7890 - val_vae_kl_loss: 5.8719\n", "Epoch 67/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.2985 - vae_r_loss: 36.4611 - vae_kl_loss: 5.8374 - val_loss: 43.6500 - val_vae_r_loss: 37.8897 - val_vae_kl_loss: 5.7603\n", "Epoch 68/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 42.2463 - vae_r_loss: 36.4053 - vae_kl_loss: 5.8411 - val_loss: 43.8904 - val_vae_r_loss: 38.1325 - val_vae_kl_loss: 5.7579\n", "Epoch 69/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.2397 - vae_r_loss: 36.4007 - vae_kl_loss: 5.8389 - val_loss: 43.7959 - val_vae_r_loss: 38.0308 - val_vae_kl_loss: 5.7651\n", "Epoch 70/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.2090 - vae_r_loss: 36.3648 - vae_kl_loss: 5.8442 - val_loss: 43.6900 - val_vae_r_loss: 37.9130 - val_vae_kl_loss: 5.7771\n", "Epoch 71/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.1998 - vae_r_loss: 36.3484 - vae_kl_loss: 5.8514 - val_loss: 43.6552 - val_vae_r_loss: 37.8492 - val_vae_kl_loss: 5.8060\n", "Epoch 72/100\n", "60000/60000 [==============================] - 7s 123us/sample - loss: 42.1943 - vae_r_loss: 36.3286 - vae_kl_loss: 5.8657 - val_loss: 43.6515 - val_vae_r_loss: 37.9546 - val_vae_kl_loss: 5.6969\n", "Epoch 
73/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.1745 - vae_r_loss: 36.3193 - vae_kl_loss: 5.8552 - val_loss: 43.8444 - val_vae_r_loss: 38.1314 - val_vae_kl_loss: 5.7130\n", "Epoch 74/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.1288 - vae_r_loss: 36.2784 - vae_kl_loss: 5.8504 - val_loss: 43.8137 - val_vae_r_loss: 38.0373 - val_vae_kl_loss: 5.7764\n", "Epoch 75/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 42.1683 - vae_r_loss: 36.3033 - vae_kl_loss: 5.8650 - val_loss: 43.6371 - val_vae_r_loss: 37.9428 - val_vae_kl_loss: 5.6942\n", "Epoch 76/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.1302 - vae_r_loss: 36.2609 - vae_kl_loss: 5.8692 - val_loss: 43.8022 - val_vae_r_loss: 37.9602 - val_vae_kl_loss: 5.8420\n", "Epoch 77/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.1186 - vae_r_loss: 36.2503 - vae_kl_loss: 5.8684 - val_loss: 43.6853 - val_vae_r_loss: 37.9111 - val_vae_kl_loss: 5.7742\n", "Epoch 78/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.1304 - vae_r_loss: 36.2550 - vae_kl_loss: 5.8754 - val_loss: 43.7015 - val_vae_r_loss: 37.8260 - val_vae_kl_loss: 5.8755\n", "Epoch 79/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0687 - vae_r_loss: 36.1876 - vae_kl_loss: 5.8811 - val_loss: 43.6678 - val_vae_r_loss: 37.7893 - val_vae_kl_loss: 5.8785\n", "Epoch 80/100\n", "60000/60000 [==============================] - 8s 125us/sample - loss: 42.0476 - vae_r_loss: 36.1643 - vae_kl_loss: 5.8833 - val_loss: 43.6656 - val_vae_r_loss: 37.8170 - val_vae_kl_loss: 5.8487\n", "Epoch 81/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0584 - vae_r_loss: 36.1825 - vae_kl_loss: 5.8759 - val_loss: 43.6267 - val_vae_r_loss: 37.8788 - val_vae_kl_loss: 5.7480\n", "Epoch 82/100\n", "60000/60000 
[==============================] - 7s 111us/sample - loss: 42.0196 - vae_r_loss: 36.1357 - vae_kl_loss: 5.8840 - val_loss: 43.7281 - val_vae_r_loss: 37.6417 - val_vae_kl_loss: 6.0864\n", "Epoch 83/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 42.0253 - vae_r_loss: 36.1311 - vae_kl_loss: 5.8943 - val_loss: 43.6205 - val_vae_r_loss: 37.8310 - val_vae_kl_loss: 5.7895\n", "Epoch 84/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0190 - vae_r_loss: 36.1276 - vae_kl_loss: 5.8914 - val_loss: 43.6444 - val_vae_r_loss: 37.8001 - val_vae_kl_loss: 5.8443\n", "Epoch 85/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0310 - vae_r_loss: 36.1556 - vae_kl_loss: 5.8754 - val_loss: 43.7125 - val_vae_r_loss: 37.8790 - val_vae_kl_loss: 5.8336\n", "Epoch 86/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 42.0147 - vae_r_loss: 36.1276 - vae_kl_loss: 5.8872 - val_loss: 43.7536 - val_vae_r_loss: 37.7771 - val_vae_kl_loss: 5.9764\n", "Epoch 87/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 41.9674 - vae_r_loss: 36.0723 - vae_kl_loss: 5.8951 - val_loss: 43.6899 - val_vae_r_loss: 37.8354 - val_vae_kl_loss: 5.8545\n", "Epoch 88/100\n", "60000/60000 [==============================] - 7s 122us/sample - loss: 41.9717 - vae_r_loss: 36.0720 - vae_kl_loss: 5.8998 - val_loss: 43.6792 - val_vae_r_loss: 37.9402 - val_vae_kl_loss: 5.7390\n", "Epoch 89/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 41.9129 - vae_r_loss: 36.0126 - vae_kl_loss: 5.9003 - val_loss: 43.6925 - val_vae_r_loss: 37.8338 - val_vae_kl_loss: 5.8587\n", "Epoch 90/100\n", "60000/60000 [==============================] - 7s 124us/sample - loss: 41.9510 - vae_r_loss: 36.0328 - vae_kl_loss: 5.9181 - val_loss: 43.7327 - val_vae_r_loss: 37.8878 - val_vae_kl_loss: 5.8448\n", "Epoch 91/100\n", "60000/60000 [==============================] - 7s 
125us/sample - loss: 41.9122 - vae_r_loss: 35.9966 - vae_kl_loss: 5.9155 - val_loss: 43.7091 - val_vae_r_loss: 37.8563 - val_vae_kl_loss: 5.8527\n", "Epoch 92/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 41.9161 - vae_r_loss: 35.9930 - vae_kl_loss: 5.9231 - val_loss: 43.7270 - val_vae_r_loss: 37.8876 - val_vae_kl_loss: 5.8393\n", "Epoch 93/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 41.9154 - vae_r_loss: 35.9875 - vae_kl_loss: 5.9278 - val_loss: 43.6541 - val_vae_r_loss: 37.6903 - val_vae_kl_loss: 5.9639\n", "Epoch 94/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 41.8807 - vae_r_loss: 35.9663 - vae_kl_loss: 5.9144 - val_loss: 43.7643 - val_vae_r_loss: 37.8280 - val_vae_kl_loss: 5.9363\n", "Epoch 95/100\n", "60000/60000 [==============================] - 7s 125us/sample - loss: 41.9016 - vae_r_loss: 35.9739 - vae_kl_loss: 5.9277 - val_loss: 43.8913 - val_vae_r_loss: 37.9501 - val_vae_kl_loss: 5.9412\n", "Epoch 96/100\n", "60000/60000 [==============================] - 7s 120us/sample - loss: 41.8545 - vae_r_loss: 35.9314 - vae_kl_loss: 5.9231 - val_loss: 43.7067 - val_vae_r_loss: 37.7875 - val_vae_kl_loss: 5.9192\n", "Epoch 97/100\n", "60000/60000 [==============================] - 7s 118us/sample - loss: 41.8349 - vae_r_loss: 35.9267 - vae_kl_loss: 5.9083 - val_loss: 43.7083 - val_vae_r_loss: 37.7909 - val_vae_kl_loss: 5.9173\n", "Epoch 98/100\n", "60000/60000 [==============================] - 7s 123us/sample - loss: 41.8574 - vae_r_loss: 35.9213 - vae_kl_loss: 5.9361 - val_loss: 43.6804 - val_vae_r_loss: 37.8716 - val_vae_kl_loss: 5.8088\n", "Epoch 99/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 41.8132 - vae_r_loss: 35.8820 - vae_kl_loss: 5.9312 - val_loss: 43.5919 - val_vae_r_loss: 37.7066 - val_vae_kl_loss: 5.8853\n", "Epoch 100/100\n", "60000/60000 [==============================] - 8s 126us/sample - loss: 41.8338 - 
vae_r_loss: 35.9009 - vae_kl_loss: 5.9329 - val_loss: 43.6792 - val_vae_r_loss: 37.6395 - val_vae_kl_loss: 6.0397\n", "\n", "Train duration : 750.18 sec. - 0:12:30\n" ] } ], "source": [ "vae.train(x_train,\n", " x_test,\n", " batch_size = batch_size, \n", " epochs = epochs,\n", " initial_epoch = initial_epoch,\n", " k_size = k_size\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }