
Update VAE with learning rate scheduler

Former-commit-id: b5ef0e38
parent 9e4e9e88
%% Cell type:markdown id: tags:
Variational AutoEncoder
=======================
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
## Variational AutoEncoder (VAE), with MNIST Dataset
%% Cell type:markdown id: tags:
## Step 1 - Init python stuff
%% Cell type:code id: tags:
``` python
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.datasets.mnist as mnist
import modules.vae
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os, sys, h5py, json
from importlib import reload
sys.path.append('..')
import fidle.pwk as ooo
ooo.init()
```
%% Output
FIDLE 2020 - Practical Work Module
Version : 0.2.5
Run time : Tuesday 4 February 2020, 00:10:15
Matplotlib style : ../fidle/talk.mplstyle
TensorFlow version : 2.0.0
Keras version : 2.2.4-tf
%% Cell type:markdown id: tags:
## Step 2 - Get data
%% Cell type:code id: tags:
``` python
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = np.expand_dims(x_train, axis=3)
x_test  = x_test.astype('float32') / 255.
x_test  = np.expand_dims(x_test, axis=3)
print(x_train.shape)
print(x_test.shape)
```
%% Output
(60000, 28, 28, 1)
(10000, 28, 28, 1)
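%% Cell type:markdown id: tags:
A quick visual sanity check of the normalized data can be useful before training; the small cell below (an illustration, using the matplotlib import from Step 1) displays a few digits with their labels.
%% Cell type:code id: tags:
``` python
# Show the first 8 normalized digits; each sample has shape (28,28,1), values in [0,1]
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for i, ax in enumerate(axes):
    ax.imshow(x_train[i].squeeze(), cmap='gray')   # drop the channel axis for imshow
    ax.set_title(str(y_train[i]))
    ax.axis('off')
plt.show()
```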
%% Cell type:markdown id: tags:
## Step 3 - Get VAE model
%% Cell type:code id: tags:
``` python
# reload(modules.vae)
# reload(modules.callbacks)
tag         = '001'
input_shape = (28,28,1)
z_dim       = 2
verbose     = 0
encoder = [ {'type':'Conv2D',  'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
            {'type':'Conv2D',  'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Conv2D',  'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Conv2D',  'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}
          ]
decoder = [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
            {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Conv2DT', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}
          ]
vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape,
                                         encoder_layers = encoder,
                                         decoder_layers = decoder,
                                         z_dim          = z_dim,
                                         verbose        = verbose,
                                         run_tag        = tag)
```
%% Output
Model initialized.
Outputs will be in : ./run/001
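%% Cell type:markdown id: tags:
The architecture is described declaratively, as lists of layer dictionaries. A minimal sketch of how such a description can be mapped onto Keras layers (the real mapping lives in modules.vae; the `build_layer` helper below is purely illustrative, and 'Conv2DT' is assumed to stand for Conv2DTranspose):
%% Cell type:code id: tags:
``` python
from tensorflow.keras.layers import Conv2D, Conv2DTranspose

def build_layer(cfg):
    # Hypothetical helper: turn one description dict into a Keras layer instance
    layer_types = {'Conv2D': Conv2D, 'Conv2DT': Conv2DTranspose}
    return layer_types[cfg['type']](filters     = cfg['filters'],
                                    kernel_size = cfg['kernel_size'],
                                    strides     = cfg['strides'],
                                    padding     = cfg['padding'],
                                    activation  = cfg['activation'])

# Example: the first encoder entry becomes a Conv2D(32, (3,3), strides=1, padding='same', activation='relu')
build_layer(encoder[0])
```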
%% Cell type:markdown id: tags:
## Step 4 - Compile it
%% Cell type:code id: tags:
``` python
learning_rate = 0.0005
r_loss_factor = 1000
vae.compile(learning_rate, r_loss_factor)
```
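%% Cell type:markdown id: tags:
compile() receives the learning rate and a weight for the reconstruction term. Judging from the metric names in the training log below (vae_r_loss, vae_kl_loss), the loss is presumably the usual weighted sum of a reconstruction term and a KL term; a hedged sketch follows (the actual definitions live in modules.vae and may differ in their reductions):
%% Cell type:code id: tags:
``` python
from tensorflow.keras import backend as K

def make_vae_losses(mu, log_var, r_loss_factor):
    # Sketch only: how the three reported quantities (loss, vae_r_loss, vae_kl_loss)
    # are typically built for a VAE of this kind.
    def vae_r_loss(y_true, y_pred):
        # Pixel-wise reconstruction error, weighted by r_loss_factor (here 1000)
        return r_loss_factor * K.mean(K.square(y_true - y_pred), axis=[1, 2, 3])
    def vae_kl_loss(y_true, y_pred):
        # KL divergence between q(z|x) = N(mu, exp(log_var)) and the prior N(0, I)
        return -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=1)
    def vae_loss(y_true, y_pred):
        return vae_r_loss(y_true, y_pred) + vae_kl_loss(y_true, y_pred)
    return vae_loss, vae_r_loss, vae_kl_loss
```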
%% Cell type:markdown id: tags:
## Step 5 - Train
%% Cell type:code id: tags:
``` python
batch_size        = 100
epochs            = 200
image_periodicity = 1      # generate example images every epoch
chkpt_periodicity = 2      # save a checkpoint every 2 epochs
initial_epoch     = 0
dataset_size      = 1
```
%% Cell type:code id: tags:
``` python
vae.train(x_train,
          x_test,
          batch_size        = batch_size,
          epochs            = epochs,
          image_periodicity = image_periodicity,
          chkpt_periodicity = chkpt_periodicity,
          initial_epoch     = initial_epoch,
          dataset_size      = dataset_size,
          lr_decay          = 1
         )
```
%% Output
Train on 60000 samples, validate on 10000 samples
Epoch 1/200
100/60000 [..............................] - ETA: 23:40 - loss: 231.4378 - vae_r_loss: 231.4373 - vae_kl_loss: 5.3801e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.251492). Check your callbacks.
60000/60000 [==============================] - 6s 101us/sample - loss: 67.7431 - vae_r_loss: 65.0691 - vae_kl_loss: 2.6740 - val_loss: 55.6598 - val_vae_r_loss: 52.4039 - val_vae_kl_loss: 3.2560
Epoch 2/200
60000/60000 [==============================] - 3s 58us/sample - loss: 54.0334 - vae_r_loss: 50.4695 - vae_kl_loss: 3.5639 - val_loss: 52.9105 - val_vae_r_loss: 49.1433 - val_vae_kl_loss: 3.7672
Epoch 3/200
60000/60000 [==============================] - 3s 57us/sample - loss: 51.8937 - vae_r_loss: 47.9195 - vae_kl_loss: 3.9743 - val_loss: 51.1775 - val_vae_r_loss: 47.0874 - val_vae_kl_loss: 4.0901
Epoch 4/200
60000/60000 [==============================] - 4s 59us/sample - loss: 50.4622 - vae_r_loss: 46.1359 - vae_kl_loss: 4.3264 - val_loss: 49.8507 - val_vae_r_loss: 45.2015 - val_vae_kl_loss: 4.6492
Epoch 5/200
60000/60000 [==============================] - 3s 57us/sample - loss: 49.3577 - vae_r_loss: 44.8123 - vae_kl_loss: 4.5454 - val_loss: 48.9416 - val_vae_r_loss: 44.3832 - val_vae_kl_loss: 4.5584
Epoch 6/200
60000/60000 [==============================] - 3s 58us/sample - loss: 48.5603 - vae_r_loss: 43.8800 - vae_kl_loss: 4.6803 - val_loss: 48.1800 - val_vae_r_loss: 43.5046 - val_vae_kl_loss: 4.6754
Epoch 7/200
60000/60000 [==============================] - 3s 57us/sample - loss: 48.0286 - vae_r_loss: 43.2646 - vae_kl_loss: 4.7640 - val_loss: 47.9362 - val_vae_r_loss: 43.2833 - val_vae_kl_loss: 4.6529
Epoch 8/200
60000/60000 [==============================] - 3s 58us/sample - loss: 47.6163 - vae_r_loss: 42.7828 - vae_kl_loss: 4.8336 - val_loss: 47.6161 - val_vae_r_loss: 42.7176 - val_vae_kl_loss: 4.8985
Epoch 9/200
60000/60000 [==============================] - 3s 58us/sample - loss: 47.2654 - vae_r_loss: 42.3804 - vae_kl_loss: 4.8850 - val_loss: 47.1385 - val_vae_r_loss: 42.2280 - val_vae_kl_loss: 4.9105
Epoch 10/200
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.732872). Check your callbacks.
100/60000 [..............................] - ETA: 7:23 - loss: 47.8688 - vae_r_loss: 43.0966 - vae_kl_loss: 4.7722WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.366450). Check your callbacks.
60000/60000 [==============================] - 4s 70us/sample - loss: 46.9698 - vae_r_loss: 42.0353 - vae_kl_loss: 4.9345 - val_loss: 47.0246 - val_vae_r_loss: 42.1103 - val_vae_kl_loss: 4.9143
Epoch 11/200
60000/60000 [==============================] - 3s 57us/sample - loss: 46.7538 - vae_r_loss: 41.7733 - vae_kl_loss: 4.9805 - val_loss: 46.9033 - val_vae_r_loss: 41.9019 - val_vae_kl_loss: 5.0014
Epoch 12/200
60000/60000 [==============================] - 3s 58us/sample - loss: 46.4962 - vae_r_loss: 41.4867 - vae_kl_loss: 5.0095 - val_loss: 46.6990 - val_vae_r_loss: 41.8006 - val_vae_kl_loss: 4.8985
Epoch 13/200
60000/60000 [==============================] - 3s 57us/sample - loss: 46.3232 - vae_r_loss: 41.2603 - vae_kl_loss: 5.0629 - val_loss: 46.6737 - val_vae_r_loss: 41.4675 - val_vae_kl_loss: 5.2061
Epoch 14/200
60000/60000 [==============================] - 3s 58us/sample - loss: 46.1505 - vae_r_loss: 41.0678 - vae_kl_loss: 5.0828 - val_loss: 46.3871 - val_vae_r_loss: 41.4687 - val_vae_kl_loss: 4.9184
Epoch 15/200
60000/60000 [==============================] - 3s 57us/sample - loss: 45.9750 - vae_r_loss: 40.8533 - vae_kl_loss: 5.1217 - val_loss: 46.1730 - val_vae_r_loss: 41.0982 - val_vae_kl_loss: 5.0748
Epoch 16/200
60000/60000 [==============================] - 3s 57us/sample - loss: 45.8053 - vae_r_loss: 40.6467 - vae_kl_loss: 5.1586 - val_loss: 46.2439 - val_vae_r_loss: 41.1142 - val_vae_kl_loss: 5.1297
Epoch 17/200
60000/60000 [==============================] - 3s 57us/sample - loss: 45.6415 - vae_r_loss: 40.4657 - vae_kl_loss: 5.1758 - val_loss: 46.0754 - val_vae_r_loss: 41.0632 - val_vae_kl_loss: 5.0122
Epoch 18/200
60000/60000 [==============================] - 3s 58us/sample - loss: 45.5121 - vae_r_loss: 40.3147 - vae_kl_loss: 5.1974 - val_loss: 45.8663 - val_vae_r_loss: 40.5329 - val_vae_kl_loss: 5.3334
Epoch 19/200
60000/60000 [==============================] - 3s 56us/sample - loss: 45.3686 - vae_r_loss: 40.1475 - vae_kl_loss: 5.2211 - val_loss: 46.2054 - val_vae_r_loss: 41.1238 - val_vae_kl_loss: 5.0816
Epoch 20/200
60000/60000 [==============================] - 3s 58us/sample - loss: 45.2161 - vae_r_loss: 39.9703 - vae_kl_loss: 5.2458 - val_loss: 45.7448 - val_vae_r_loss: 40.6166 - val_vae_kl_loss: 5.1283
Epoch 21/200
60000/60000 [==============================] - 3s 56us/sample - loss: 45.1159 - vae_r_loss: 39.8419 - vae_kl_loss: 5.2740 - val_loss: 45.8612 - val_vae_r_loss: 40.8692 - val_vae_kl_loss: 4.9920
Epoch 22/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.9881 - vae_r_loss: 39.7023 - vae_kl_loss: 5.2857 - val_loss: 45.8085 - val_vae_r_loss: 40.2675 - val_vae_kl_loss: 5.5410
Epoch 23/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.8471 - vae_r_loss: 39.5384 - vae_kl_loss: 5.3087 - val_loss: 45.4330 - val_vae_r_loss: 40.0743 - val_vae_kl_loss: 5.3587
Epoch 24/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.7550 - vae_r_loss: 39.4362 - vae_kl_loss: 5.3188 - val_loss: 45.3320 - val_vae_r_loss: 39.9992 - val_vae_kl_loss: 5.3328
Epoch 25/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.6692 - vae_r_loss: 39.3461 - vae_kl_loss: 5.3232 - val_loss: 45.3552 - val_vae_r_loss: 40.0258 - val_vae_kl_loss: 5.3294
Epoch 26/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.5891 - vae_r_loss: 39.2333 - vae_kl_loss: 5.3558 - val_loss: 45.2681 - val_vae_r_loss: 39.9015 - val_vae_kl_loss: 5.3666
Epoch 27/200
60000/60000 [==============================] - 3s 56us/sample - loss: 44.5072 - vae_r_loss: 39.1374 - vae_kl_loss: 5.3698 - val_loss: 45.3209 - val_vae_r_loss: 39.9636 - val_vae_kl_loss: 5.3574
Epoch 28/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.4180 - vae_r_loss: 39.0149 - vae_kl_loss: 5.4031 - val_loss: 45.2435 - val_vae_r_loss: 39.7765 - val_vae_kl_loss: 5.4671
Epoch 29/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.3102 - vae_r_loss: 38.9046 - vae_kl_loss: 5.4057 - val_loss: 45.2258 - val_vae_r_loss: 39.8441 - val_vae_kl_loss: 5.3817
Epoch 30/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.2489 - vae_r_loss: 38.8299 - vae_kl_loss: 5.4190 - val_loss: 45.0044 - val_vae_r_loss: 39.6516 - val_vae_kl_loss: 5.3528
Epoch 31/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.1732 - vae_r_loss: 38.7482 - vae_kl_loss: 5.4249 - val_loss: 45.0000 - val_vae_r_loss: 39.5609 - val_vae_kl_loss: 5.4391
Epoch 32/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.0894 - vae_r_loss: 38.6580 - vae_kl_loss: 5.4314 - val_loss: 44.9769 - val_vae_r_loss: 39.5384 - val_vae_kl_loss: 5.4385
Epoch 33/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.0582 - vae_r_loss: 38.6092 - vae_kl_loss: 5.4490 - val_loss: 44.9346 - val_vae_r_loss: 39.3805 - val_vae_kl_loss: 5.5541
Epoch 34/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.9458 - vae_r_loss: 38.4818 - vae_kl_loss: 5.4640 - val_loss: 45.0624 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4813
Epoch 35/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.8850 - vae_r_loss: 38.4031 - vae_kl_loss: 5.4819 - val_loss: 45.0285 - val_vae_r_loss: 39.5350 - val_vae_kl_loss: 5.4935
Epoch 36/200
60000/60000 [==============================] - 3s 58us/sample - loss: 43.8698 - vae_r_loss: 38.3779 - vae_kl_loss: 5.4918 - val_loss: 44.9170 - val_vae_r_loss: 39.5714 - val_vae_kl_loss: 5.3456
Epoch 37/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.7739 - vae_r_loss: 38.2723 - vae_kl_loss: 5.5016 - val_loss: 44.8441 - val_vae_r_loss: 39.3665 - val_vae_kl_loss: 5.4776
Epoch 38/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.7084 - vae_r_loss: 38.1933 - vae_kl_loss: 5.5151 - val_loss: 44.9233 - val_vae_r_loss: 39.5526 - val_vae_kl_loss: 5.3706
Epoch 39/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.6626 - vae_r_loss: 38.1320 - vae_kl_loss: 5.5306 - val_loss: 44.6793 - val_vae_r_loss: 39.2304 - val_vae_kl_loss: 5.4489
Epoch 40/200
60000/60000 [==============================] - 3s 58us/sample - loss: 43.5838 - vae_r_loss: 38.0592 - vae_kl_loss: 5.5246 - val_loss: 44.6130 - val_vae_r_loss: 39.0715 - val_vae_kl_loss: 5.5415
Epoch 41/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.5194 - vae_r_loss: 37.9840 - vae_kl_loss: 5.5354 - val_loss: 44.8512 - val_vae_r_loss: 39.6158 - val_vae_kl_loss: 5.2354
Epoch 42/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.5129 - vae_r_loss: 37.9786 - vae_kl_loss: 5.5343 - val_loss: 44.6991 - val_vae_r_loss: 39.2098 - val_vae_kl_loss: 5.4894
Epoch 43/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.4707 - vae_r_loss: 37.9237 - vae_kl_loss: 5.5470 - val_loss: 44.7121 - val_vae_r_loss: 39.2446 - val_vae_kl_loss: 5.4675
Epoch 44/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.3832 - vae_r_loss: 37.8227 - vae_kl_loss: 5.5604 - val_loss: 44.9172 - val_vae_r_loss: 39.3446 - val_vae_kl_loss: 5.5726
Epoch 45/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.3868 - vae_r_loss: 37.8075 - vae_kl_loss: 5.5793 - val_loss: 44.5718 - val_vae_r_loss: 39.0284 - val_vae_kl_loss: 5.5434
Epoch 46/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.2774 - vae_r_loss: 37.6953 - vae_kl_loss: 5.5821 - val_loss: 44.6954 - val_vae_r_loss: 39.1276 - val_vae_kl_loss: 5.5678
Epoch 47/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.2765 - vae_r_loss: 37.6813 - vae_kl_loss: 5.5952 - val_loss: 44.6153 - val_vae_r_loss: 38.9606 - val_vae_kl_loss: 5.6547
Epoch 48/200
60000/60000 [==============================] - 3s 58us/sample - loss: 43.2385 - vae_r_loss: 37.6431 - vae_kl_loss: 5.5954 - val_loss: 44.5508 - val_vae_r_loss: 39.0830 - val_vae_kl_loss: 5.4678
Epoch 49/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.1847 - vae_r_loss: 37.5822 - vae_kl_loss: 5.6025 - val_loss: 44.8277 - val_vae_r_loss: 39.1688 - val_vae_kl_loss: 5.6589
Epoch 50/200
60000/60000 [==============================] - 4s 58us/sample - loss: 43.1557 - vae_r_loss: 37.5533 - vae_kl_loss: 5.6024 - val_loss: 44.5082 - val_vae_r_loss: 38.9529 - val_vae_kl_loss: 5.5553
Epoch 51/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.0726 - vae_r_loss: 37.4533 - vae_kl_loss: 5.6193 - val_loss: 44.6332 - val_vae_r_loss: 38.9104 - val_vae_kl_loss: 5.7228
Epoch 52/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.1003 - vae_r_loss: 37.4708 - vae_kl_loss: 5.6295 - val_loss: 44.5279 - val_vae_r_loss: 39.0846 - val_vae_kl_loss: 5.4433
Epoch 53/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.0121 - vae_r_loss: 37.3923 - vae_kl_loss: 5.6198 - val_loss: 44.5675 - val_vae_r_loss: 38.9651 - val_vae_kl_loss: 5.6024
Epoch 54/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.9750 - vae_r_loss: 37.3273 - vae_kl_loss: 5.6477 - val_loss: 44.6084 - val_vae_r_loss: 39.0057 - val_vae_kl_loss: 5.6027
Epoch 55/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.9669 - vae_r_loss: 37.3124 - vae_kl_loss: 5.6545 - val_loss: 44.4369 - val_vae_r_loss: 38.7499 - val_vae_kl_loss: 5.6870
Epoch 56/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.9172 - vae_r_loss: 37.2666 - vae_kl_loss: 5.6506 - val_loss: 44.4817 - val_vae_r_loss: 38.8071 - val_vae_kl_loss: 5.6747
Epoch 57/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.8719 - vae_r_loss: 37.2088 - vae_kl_loss: 5.6630 - val_loss: 44.7545 - val_vae_r_loss: 39.1340 - val_vae_kl_loss: 5.6205
Epoch 58/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.8724 - vae_r_loss: 37.2070 - vae_kl_loss: 5.6654 - val_loss: 44.4428 - val_vae_r_loss: 38.8374 - val_vae_kl_loss: 5.6054
Epoch 59/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.8085 - vae_r_loss: 37.1356 - vae_kl_loss: 5.6729 - val_loss: 44.3657 - val_vae_r_loss: 38.8973 - val_vae_kl_loss: 5.4684
Epoch 60/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.7711 - vae_r_loss: 37.1025 - vae_kl_loss: 5.6687 - val_loss: 44.5526 - val_vae_r_loss: 38.7923 - val_vae_kl_loss: 5.7603
Epoch 61/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.7549 - vae_r_loss: 37.0712 - vae_kl_loss: 5.6837 - val_loss: 44.6274 - val_vae_r_loss: 39.1211 - val_vae_kl_loss: 5.5063
Epoch 62/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.7314 - vae_r_loss: 37.0368 - vae_kl_loss: 5.6946 - val_loss: 44.3828 - val_vae_r_loss: 38.8327 - val_vae_kl_loss: 5.5502
Epoch 63/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.6688 - vae_r_loss: 36.9835 - vae_kl_loss: 5.6853 - val_loss: 44.4869 - val_vae_r_loss: 38.8497 - val_vae_kl_loss: 5.6372
Epoch 64/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.6714 - vae_r_loss: 36.9633 - vae_kl_loss: 5.7080 - val_loss: 44.4562 - val_vae_r_loss: 38.7178 - val_vae_kl_loss: 5.7384
Epoch 65/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.6547 - vae_r_loss: 36.9360 - vae_kl_loss: 5.7187 - val_loss: 44.4947 - val_vae_r_loss: 38.8561 - val_vae_kl_loss: 5.6386
Epoch 66/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.5807 - vae_r_loss: 36.8625 - vae_kl_loss: 5.7182 - val_loss: 44.4270 - val_vae_r_loss: 38.7251 - val_vae_kl_loss: 5.7019
Epoch 67/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.5664 - vae_r_loss: 36.8466 - vae_kl_loss: 5.7197 - val_loss: 44.5878 - val_vae_r_loss: 38.8787 - val_vae_kl_loss: 5.7091
Epoch 68/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.5503 - vae_r_loss: 36.8269 - vae_kl_loss: 5.7235 - val_loss: 44.6236 - val_vae_r_loss: 38.8846 - val_vae_kl_loss: 5.7390
Epoch 69/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.5057 - vae_r_loss: 36.7706 - vae_kl_loss: 5.7352 - val_loss: 44.5720 - val_vae_r_loss: 38.9196 - val_vae_kl_loss: 5.6525
Epoch 70/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.4955 - vae_r_loss: 36.7553 - vae_kl_loss: 5.7402 - val_loss: 44.4059 - val_vae_r_loss: 38.8886 - val_vae_kl_loss: 5.5173
Epoch 71/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.4649 - vae_r_loss: 36.7251 - vae_kl_loss: 5.7398 - val_loss: 44.5864 - val_vae_r_loss: 38.8203 - val_vae_kl_loss: 5.7661
Epoch 72/200
60000/60000 [==============================] - 3s 58us/sample - loss: 42.4907 - vae_r_loss: 36.7440 - vae_kl_loss: 5.7467 - val_loss: 44.3493 - val_vae_r_loss: 38.6765 - val_vae_kl_loss: 5.6727
Epoch 73/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.4224 - vae_r_loss: 36.6558 - vae_kl_loss: 5.7666 - val_loss: 44.5477 - val_vae_r_loss: 38.7588 - val_vae_kl_loss: 5.7889
Epoch 74/200
43100/60000 [====================>.........] - ETA: 0s - loss: 42.3141 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7565
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
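%% Cell type:markdown id: tags:
Note that the run above passes lr_decay = 1: with the step decay schedule this commit introduces (see the updated module files below), a factor of 1 keeps the learning rate constant at 0.0005. A small illustration of the decay rule, with a hypothetical factor of 0.9 for comparison:
%% Cell type:code id: tags:
``` python
import numpy as np

def step_decay(initial_lr, decay_factor, step_size, epoch):
    # Same rule as the commit's step_decay_schedule: piecewise-constant exponential decay
    return initial_lr * decay_factor ** np.floor(epoch / step_size)

for epoch in range(4):
    print(f"epoch {epoch}:  lr_decay=1 -> {step_decay(0.0005, 1.0, 1, epoch):.6f}"
          f"   lr_decay=0.9 -> {step_decay(0.0005, 0.9, 1, epoch):.6f}")
```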
modules/callbacks.py (updated):

from tensorflow.keras.callbacks import Callback, LearningRateScheduler
import numpy as np
import matplotlib.pyplot as plt
import os

@@ -33,3 +33,14 @@ class ImagesCallback(Callback):
    def on_epoch_begin(self, epoch, logs={}):
        self.epoch += 1

def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    '''
    Wrapper function to create a LearningRateScheduler with a step decay schedule.
    '''
    def schedule(epoch):
        new_lr = initial_lr * (decay_factor ** np.floor(epoch/step_size))
        return new_lr
    return LearningRateScheduler(schedule)
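A usage sketch for the new helper, wired into a toy Keras model so it runs standalone (the model and data here are illustrative, not the VAE above):

``` python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import LearningRateScheduler

def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    # Redefined here so the demo is self-contained; identical to the helper above
    def schedule(epoch):
        return initial_lr * (decay_factor ** np.floor(epoch / step_size))
    return LearningRateScheduler(schedule)

# Toy model and data, just to show the callback wiring
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
x = np.random.rand(32, 4)
y = np.random.rand(32, 1)
model.fit(x, y, epochs=3, verbose=0,
          callbacks=[step_decay_schedule(0.0005, decay_factor=0.5, step_size=1)])
print(K.get_value(model.optimizer.lr))   # 0.0005 * 0.5**2 after the third epoch starts
```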
modules/vae.py (updated):

@@ -6,7 +6,7 @@ from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda
from tensorflow.keras.layers import Activation, BatchNormalization, LeakyReLU, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model

@@ -161,18 +161,25 @@ class VariationalAutoencoder():
        self.n_test     = n_test
        self.batch_size = batch_size

        # ---- Callback : Images
        callbacks_images = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)

        # ---- Callback : Learning rate scheduler
        lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)

        # ---- Callback : Checkpoint
        filename = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
        callback_chkpts = ModelCheckpoint(filename, save_freq=n_train*chkpt_periodicity, verbose=0)

        # ---- Callback : Best model
        filename = self.run_directory+"/models/best_model.h5"
        callback_bestmodel = ModelCheckpoint(filename, save_best_only=True, mode='min', monitor='val_loss', verbose=0)

        # ---- Callback : TensorBoard
        dirname = self.run_directory+"/logs"
        callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)

        callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard, lr_sched]

        self.model.fit(x_train[:n_train], x_train[:n_train],
                       batch_size = batch_size,

@@ -189,3 +196,5 @@ class VariationalAutoencoder():
        plot_model(self.model,   to_file=f'{d}/model.png',   show_shapes=True, show_layer_names=True, expand_nested=True)
        plot_model(self.encoder, to_file=f'{d}/encoder.png', show_shapes=True, show_layer_names=True)
        plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes=True, show_layer_names=True)
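The new TensorBoard callback writes to run_directory/logs, i.e. run/001/logs for the tag used above. Assuming the tensorboard package is installed, the logs can be inspected directly from a notebook cell via the TensorBoard Jupyter extension:

``` python
# Load the TensorBoard notebook extension and point it at this run's log directory
%load_ext tensorboard
%tensorboard --logdir run/001/logs
```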