Skip to content
Snippets Groups Projects

$#&@! de VAE...

parent baf57efa
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
Variational AutoEncoder (VAE) with MNIST
========================================
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
## Episode 1 - Train a model
- Defining a VAE model
- Build the model
- Train it
- Follow the learning process with Tensorboard
%% Cell type:markdown id: tags:
## Step 1 - Init python stuff
%% Cell type:code id: tags:
``` python
# ---- Bring in the FIDLE helper modules and show the version banner.
import numpy as np
import sys, importlib          # importlib presumably kept for interactive module reloading -- TODO confirm
import modules.vae
import modules.loader_MNIST
from modules.vae import VariationalAutoencoder
from modules.loader_MNIST import Loader_MNIST
# Prints module / TensorFlow version information (see cell output).
VariationalAutoencoder.about()
```
%% Output
FIDLE 2020 - Variational AutoEncoder (VAE)
TensorFlow version : 2.0.0
VAE version : 1.24
%% Cell type:markdown id: tags:
## Step 2 - Get data
%% Cell type:code id: tags:
``` python
# ---- Load MNIST; per the cell output, the loader also normalizes the images
#      and reshapes them to (n, 28, 28, 1).
(x_train, y_train), (x_test, y_test) = Loader_MNIST.load()
```
%% Output
Dataset loaded.
Normalized.
Reshaped to (60000, 28, 28, 1)
%% Cell type:markdown id: tags:
## Step 3 - Get VAE model
%% Cell type:code id: tags:
``` python
# ---- VAE definition ------------------------------------------------------
#      tag         : run identifier (outputs go to ./run/<tag>)
#      input_shape : MNIST image shape
#      z_dim       : latent-space dimension
tag         = '001'
input_shape = (28,28,1)
z_dim       = 2
verbose     = 0

def _conv(layer_type, filters, strides, activation):
    """Return the description of one 3x3 'same'-padded conv layer."""
    return {'type':layer_type, 'filters':filters, 'kernel_size':(3,3),
            'strides':strides, 'padding':'same', 'activation':activation}

# ---- Encoder : 28x28x1 -> latent space (two stride-2 downsamplings)
encoder = [ _conv('Conv2D', 32, 1, 'relu'),
            _conv('Conv2D', 64, 2, 'relu'),
            _conv('Conv2D', 64, 2, 'relu'),
            _conv('Conv2D', 64, 1, 'relu') ]

# ---- Decoder : latent space -> 28x28x1 (mirror of the encoder, sigmoid out)
decoder = [ _conv('Conv2DTranspose', 64, 1, 'relu'),
            _conv('Conv2DTranspose', 64, 2, 'relu'),
            _conv('Conv2DTranspose', 32, 2, 'relu'),
            _conv('Conv2DTranspose',  1, 1, 'sigmoid') ]

vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape,
                                         encoder_layers = encoder,
                                         decoder_layers = decoder,
                                         z_dim          = z_dim,
                                         verbose        = verbose,
                                         run_tag        = tag)
# Saves the model configuration (not the weights) under ./run/<tag>/models.
vae.save(model=None)
```
%% Output
Model initialized.
Outputs will be in : ./run/001
Config saved in : ./run/001/models/vae_config.json
%% Cell type:markdown id: tags:
## Step 4 - Compile it
%% Cell type:code id: tags:
``` python
# ---- Compile the VAE (Adam optimizer per the cell output).
#      r_loss_factor weights the reconstruction loss against the KL loss
#      -- NOTE(review): exact combination is defined in modules.vae; confirm there.
learning_rate = 0.0005
r_loss_factor = 1000
vae.compile(learning_rate, r_loss_factor)
```
%% Output
Compiled.
Optimizer is Adam with learning_rate=0.0005
%% Cell type:markdown id: tags:
## Step 5 - Train
%% Cell type:code id: tags:
``` python
# ---- Training hyperparameters
batch_size = 100
epochs = 100
image_periodicity = 1 # save sample images every N epochs
chkpt_periodicity = 2 # save a model checkpoint every N epochs
initial_epoch = 0 # first epoch number (useful when resuming a run)
dataset_size = 1 # fraction of the dataset to use; 1 means 100% -- NOTE(review): superseded by k_size in the updated train() signature
k_size = 1 # fraction of the dataset to use; 1 means 100%
```
%% Cell type:code id: tags:
``` python
# ---- Train the VAE.
#      Fix: the original call was missing a comma after `dataset_size = dataset_size`
#      (a SyntaxError), and `dataset_size` itself is no longer a parameter of
#      train() -- it was replaced by `k_size` (fraction of the dataset to use).
vae.train(x_train,
          x_test,
          batch_size        = batch_size,
          epochs            = epochs,
          image_periodicity = image_periodicity,
          chkpt_periodicity = chkpt_periodicity,
          initial_epoch     = initial_epoch,
          k_size            = k_size
         )
```
%% Output
Train on 60000 samples, validate on 10000 samples
Epoch 1/100
100/60000 [..............................] - ETA: 25:44 - loss: 232.2070 - vae_r_loss: 232.2063 - vae_kl_loss: 6.6949e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.288287). Check your callbacks.
60000/60000 [==============================] - 10s 172us/sample - loss: 69.2860 - vae_r_loss: 66.7599 - vae_kl_loss: 2.5261 - val_loss: 53.5040 - val_vae_r_loss: 49.8820 - val_vae_kl_loss: 3.6220
Epoch 2/100
60000/60000 [==============================] - 7s 124us/sample - loss: 51.8624 - vae_r_loss: 47.9126 - vae_kl_loss: 3.9498 - val_loss: 50.5979 - val_vae_r_loss: 46.4222 - val_vae_kl_loss: 4.1757
Epoch 3/100
60000/60000 [==============================] - 7s 123us/sample - loss: 49.8990 - vae_r_loss: 45.5683 - vae_kl_loss: 4.3307 - val_loss: 49.3813 - val_vae_r_loss: 44.9857 - val_vae_kl_loss: 4.3957
Epoch 4/100
60000/60000 [==============================] - 8s 132us/sample - loss: 48.8134 - vae_r_loss: 44.2657 - vae_kl_loss: 4.5477 - val_loss: 48.3869 - val_vae_r_loss: 43.7678 - val_vae_kl_loss: 4.6190
Epoch 5/100
60000/60000 [==============================] - 8s 125us/sample - loss: 48.0753 - vae_r_loss: 43.4002 - vae_kl_loss: 4.6751 - val_loss: 47.8571 - val_vae_r_loss: 43.3078 - val_vae_kl_loss: 4.5493
Epoch 6/100
60000/60000 [==============================] - 8s 128us/sample - loss: 47.5525 - vae_r_loss: 42.7672 - vae_kl_loss: 4.7853 - val_loss: 47.4546 - val_vae_r_loss: 42.5697 - val_vae_kl_loss: 4.8849
Epoch 7/100
60000/60000 [==============================] - 8s 130us/sample - loss: 47.1005 - vae_r_loss: 42.2315 - vae_kl_loss: 4.8690 - val_loss: 47.3139 - val_vae_r_loss: 42.4876 - val_vae_kl_loss: 4.8263
Epoch 8/100
60000/60000 [==============================] - 8s 132us/sample - loss: 46.7219 - vae_r_loss: 41.8030 - vae_kl_loss: 4.9189 - val_loss: 46.8271 - val_vae_r_loss: 42.0297 - val_vae_kl_loss: 4.7974
Epoch 9/100
60000/60000 [==============================] - 8s 130us/sample - loss: 46.4398 - vae_r_loss: 41.4570 - vae_kl_loss: 4.9827 - val_loss: 46.3167 - val_vae_r_loss: 41.2337 - val_vae_kl_loss: 5.0830
Epoch 10/100
60000/60000 [==============================] - 8s 128us/sample - loss: 46.1623 - vae_r_loss: 41.1249 - vae_kl_loss: 5.0374 - val_loss: 46.1499 - val_vae_r_loss: 41.0398 - val_vae_kl_loss: 5.1101
Epoch 11/100
60000/60000 [==============================] - 7s 122us/sample - loss: 45.9183 - vae_r_loss: 40.8327 - vae_kl_loss: 5.0856 - val_loss: 45.9335 - val_vae_r_loss: 40.9553 - val_vae_kl_loss: 4.9782
Epoch 12/100
60000/60000 [==============================] - 8s 131us/sample - loss: 45.6792 - vae_r_loss: 40.5563 - vae_kl_loss: 5.1229 - val_loss: 45.9237 - val_vae_r_loss: 40.6965 - val_vae_kl_loss: 5.2272
Epoch 13/100
60000/60000 [==============================] - 8s 127us/sample - loss: 45.5282 - vae_r_loss: 40.3813 - vae_kl_loss: 5.1469 - val_loss: 45.6323 - val_vae_r_loss: 40.4588 - val_vae_kl_loss: 5.1734
Epoch 14/100
60000/60000 [==============================] - 8s 127us/sample - loss: 45.2921 - vae_r_loss: 40.1040 - vae_kl_loss: 5.1882 - val_loss: 45.8203 - val_vae_r_loss: 40.7723 - val_vae_kl_loss: 5.0480
Epoch 15/100
60000/60000 [==============================] - 8s 132us/sample - loss: 45.1484 - vae_r_loss: 39.9344 - vae_kl_loss: 5.2140 - val_loss: 45.3385 - val_vae_r_loss: 40.1524 - val_vae_kl_loss: 5.1861
Epoch 16/100
60000/60000 [==============================] - 8s 130us/sample - loss: 45.0287 - vae_r_loss: 39.7829 - vae_kl_loss: 5.2458 - val_loss: 45.2159 - val_vae_r_loss: 40.0134 - val_vae_kl_loss: 5.2025
Epoch 17/100
60000/60000 [==============================] - 8s 131us/sample - loss: 44.8389 - vae_r_loss: 39.5743 - vae_kl_loss: 5.2646 - val_loss: 45.1018 - val_vae_r_loss: 39.8287 - val_vae_kl_loss: 5.2731
Epoch 18/100
60000/60000 [==============================] - 8s 128us/sample - loss: 44.7559 - vae_r_loss: 39.4709 - vae_kl_loss: 5.2850 - val_loss: 45.0585 - val_vae_r_loss: 39.4544 - val_vae_kl_loss: 5.6040
Epoch 19/100
60000/60000 [==============================] - 8s 129us/sample - loss: 44.6301 - vae_r_loss: 39.3133 - vae_kl_loss: 5.3168 - val_loss: 45.0805 - val_vae_r_loss: 39.9210 - val_vae_kl_loss: 5.1595
Epoch 20/100
60000/60000 [==============================] - 8s 131us/sample - loss: 44.5006 - vae_r_loss: 39.1616 - vae_kl_loss: 5.3390 - val_loss: 44.9545 - val_vae_r_loss: 39.5154 - val_vae_kl_loss: 5.4391
Epoch 21/100
60000/60000 [==============================] - 7s 122us/sample - loss: 44.4012 - vae_r_loss: 39.0479 - vae_kl_loss: 5.3533 - val_loss: 45.0373 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4562
Epoch 22/100
60000/60000 [==============================] - 8s 125us/sample - loss: 44.3708 - vae_r_loss: 38.9988 - vae_kl_loss: 5.3721 - val_loss: 44.8998 - val_vae_r_loss: 39.5490 - val_vae_kl_loss: 5.3509
Epoch 23/100
60000/60000 [==============================] - 8s 128us/sample - loss: 44.1875 - vae_r_loss: 38.7919 - vae_kl_loss: 5.3956 - val_loss: 44.7653 - val_vae_r_loss: 39.4184 - val_vae_kl_loss: 5.3469
Epoch 24/100
60000/60000 [==============================] - 8s 127us/sample - loss: 44.0801 - vae_r_loss: 38.6846 - vae_kl_loss: 5.3956 - val_loss: 44.5562 - val_vae_r_loss: 39.1243 - val_vae_kl_loss: 5.4318
Epoch 25/100
60000/60000 [==============================] - 8s 130us/sample - loss: 44.0227 - vae_r_loss: 38.5986 - vae_kl_loss: 5.4242 - val_loss: 44.7977 - val_vae_r_loss: 39.4418 - val_vae_kl_loss: 5.3559
Epoch 26/100
60000/60000 [==============================] - 8s 126us/sample - loss: 43.9873 - vae_r_loss: 38.5459 - vae_kl_loss: 5.4415 - val_loss: 44.6459 - val_vae_r_loss: 39.1856 - val_vae_kl_loss: 5.4603
Epoch 27/100
60000/60000 [==============================] - 7s 125us/sample - loss: 43.8557 - vae_r_loss: 38.4107 - vae_kl_loss: 5.4450 - val_loss: 44.8198 - val_vae_r_loss: 39.2364 - val_vae_kl_loss: 5.5834
Epoch 28/100
60000/60000 [==============================] - 8s 130us/sample - loss: 43.8490 - vae_r_loss: 38.3700 - vae_kl_loss: 5.4790 - val_loss: 44.5917 - val_vae_r_loss: 39.3252 - val_vae_kl_loss: 5.2665
Epoch 29/100
60000/60000 [==============================] - 8s 126us/sample - loss: 43.7319 - vae_r_loss: 38.2444 - vae_kl_loss: 5.4876 - val_loss: 44.4550 - val_vae_r_loss: 38.9941 - val_vae_kl_loss: 5.4609
Epoch 30/100
60000/60000 [==============================] - 8s 129us/sample - loss: 43.6520 - vae_r_loss: 38.1524 - vae_kl_loss: 5.4996 - val_loss: 44.3490 - val_vae_r_loss: 38.6813 - val_vae_kl_loss: 5.6678
Epoch 31/100
60000/60000 [==============================] - 8s 125us/sample - loss: 43.6014 - vae_r_loss: 38.1041 - vae_kl_loss: 5.4973 - val_loss: 44.3030 - val_vae_r_loss: 38.9108 - val_vae_kl_loss: 5.3923
Epoch 32/100
60000/60000 [==============================] - 8s 128us/sample - loss: 43.5550 - vae_r_loss: 38.0354 - vae_kl_loss: 5.5195 - val_loss: 44.5968 - val_vae_r_loss: 39.3021 - val_vae_kl_loss: 5.2947
Epoch 33/100
60000/60000 [==============================] - 8s 129us/sample - loss: 43.4544 - vae_r_loss: 37.9230 - vae_kl_loss: 5.5314 - val_loss: 44.4347 - val_vae_r_loss: 39.0158 - val_vae_kl_loss: 5.4189
Epoch 34/100
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.598076). Check your callbacks.
100/60000 [..............................] - ETA: 6:06 - loss: 43.3305 - vae_r_loss: 37.8705 - vae_kl_loss: 5.4600WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.299070). Check your callbacks.
60000/60000 [==============================] - 8s 133us/sample - loss: 43.4399 - vae_r_loss: 37.9190 - vae_kl_loss: 5.5209 - val_loss: 44.8390 - val_vae_r_loss: 39.5000 - val_vae_kl_loss: 5.3391
Epoch 35/100
60000/60000 [==============================] - 8s 129us/sample - loss: 43.3421 - vae_r_loss: 37.8032 - vae_kl_loss: 5.5389 - val_loss: 44.2497 - val_vae_r_loss: 38.7430 - val_vae_kl_loss: 5.5067
Epoch 36/100
60000/60000 [==============================] - 8s 135us/sample - loss: 43.2754 - vae_r_loss: 37.7261 - vae_kl_loss: 5.5493 - val_loss: 44.1556 - val_vae_r_loss: 38.8415 - val_vae_kl_loss: 5.3141
Epoch 37/100
60000/60000 [==============================] - 8s 130us/sample - loss: 43.2858 - vae_r_loss: 37.7115 - vae_kl_loss: 5.5743 - val_loss: 44.2977 - val_vae_r_loss: 38.7636 - val_vae_kl_loss: 5.5341
Epoch 38/100
60000/60000 [==============================] - 8s 132us/sample - loss: 43.1697 - vae_r_loss: 37.6031 - vae_kl_loss: 5.5666 - val_loss: 44.2753 - val_vae_r_loss: 38.8988 - val_vae_kl_loss: 5.3765
Epoch 39/100
60000/60000 [==============================] - 8s 129us/sample - loss: 43.1452 - vae_r_loss: 37.5476 - vae_kl_loss: 5.5976 - val_loss: 44.5351 - val_vae_r_loss: 38.8169 - val_vae_kl_loss: 5.7182
Epoch 40/100
60000/60000 [==============================] - 8s 128us/sample - loss: 43.0643 - vae_r_loss: 37.4672 - vae_kl_loss: 5.5971 - val_loss: 44.0204 - val_vae_r_loss: 38.5774 - val_vae_kl_loss: 5.4431
Epoch 41/100
60000/60000 [==============================] - 8s 128us/sample - loss: 43.0277 - vae_r_loss: 37.4318 - vae_kl_loss: 5.5959 - val_loss: 43.9885 - val_vae_r_loss: 38.5233 - val_vae_kl_loss: 5.4652
Epoch 42/100
60000/60000 [==============================] - 8s 128us/sample - loss: 42.9816 - vae_r_loss: 37.3741 - vae_kl_loss: 5.6075 - val_loss: 44.1357 - val_vae_r_loss: 38.5023 - val_vae_kl_loss: 5.6334
Epoch 43/100
60000/60000 [==============================] - 8s 127us/sample - loss: 42.9420 - vae_r_loss: 37.3163 - vae_kl_loss: 5.6258 - val_loss: 44.2125 - val_vae_r_loss: 38.6341 - val_vae_kl_loss: 5.5783
Epoch 44/100
60000/60000 [==============================] - 7s 122us/sample - loss: 42.9049 - vae_r_loss: 37.2851 - vae_kl_loss: 5.6198 - val_loss: 44.0663 - val_vae_r_loss: 38.4909 - val_vae_kl_loss: 5.5755
Epoch 45/100
60000/60000 [==============================] - 7s 118us/sample - loss: 42.8474 - vae_r_loss: 37.2118 - vae_kl_loss: 5.6356 - val_loss: 44.1876 - val_vae_r_loss: 38.5833 - val_vae_kl_loss: 5.6043
Epoch 46/100
60000/60000 [==============================] - 8s 130us/sample - loss: 42.8250 - vae_r_loss: 37.1782 - vae_kl_loss: 5.6468 - val_loss: 44.0636 - val_vae_r_loss: 38.4266 - val_vae_kl_loss: 5.6371
Epoch 47/100
60000/60000 [==============================] - 7s 125us/sample - loss: 42.7751 - vae_r_loss: 37.1261 - vae_kl_loss: 5.6489 - val_loss: 44.0862 - val_vae_r_loss: 38.3833 - val_vae_kl_loss: 5.7029
Epoch 48/100
60000/60000 [==============================] - 8s 129us/sample - loss: 42.7732 - vae_r_loss: 37.1227 - vae_kl_loss: 5.6504 - val_loss: 44.1535 - val_vae_r_loss: 38.5943 - val_vae_kl_loss: 5.5592
Epoch 49/100
60000/60000 [==============================] - 8s 125us/sample - loss: 42.6997 - vae_r_loss: 37.0276 - vae_kl_loss: 5.6721 - val_loss: 44.0287 - val_vae_r_loss: 38.2934 - val_vae_kl_loss: 5.7352
Epoch 50/100
60000/60000 [==============================] - 8s 127us/sample - loss: 42.7032 - vae_r_loss: 37.0320 - vae_kl_loss: 5.6712 - val_loss: 44.1065 - val_vae_r_loss: 38.4361 - val_vae_kl_loss: 5.6704
Epoch 51/100
60000/60000 [==============================] - 8s 131us/sample - loss: 42.6594 - vae_r_loss: 36.9766 - vae_kl_loss: 5.6829 - val_loss: 43.9345 - val_vae_r_loss: 38.4122 - val_vae_kl_loss: 5.5223
Epoch 52/100
60000/60000 [==============================] - 8s 126us/sample - loss: 42.6146 - vae_r_loss: 36.9325 - vae_kl_loss: 5.6821 - val_loss: 43.9331 - val_vae_r_loss: 38.2921 - val_vae_kl_loss: 5.6410
Epoch 53/100
60000/60000 [==============================] - 8s 130us/sample - loss: 42.5738 - vae_r_loss: 36.8788 - vae_kl_loss: 5.6949 - val_loss: 43.9100 - val_vae_r_loss: 38.0716 - val_vae_kl_loss: 5.8385
Epoch 54/100
60000/60000 [==============================] - 8s 132us/sample - loss: 42.5322 - vae_r_loss: 36.8424 - vae_kl_loss: 5.6898 - val_loss: 43.8390 - val_vae_r_loss: 38.1401 - val_vae_kl_loss: 5.6989
Epoch 55/100
60000/60000 [==============================] - 8s 127us/sample - loss: 42.5459 - vae_r_loss: 36.8329 - vae_kl_loss: 5.7130 - val_loss: 43.8115 - val_vae_r_loss: 38.1122 - val_vae_kl_loss: 5.6993
Epoch 56/100
60000/60000 [==============================] - 8s 128us/sample - loss: 42.5097 - vae_r_loss: 36.7918 - vae_kl_loss: 5.7178 - val_loss: 44.2181 - val_vae_r_loss: 38.4090 - val_vae_kl_loss: 5.8091
Epoch 57/100
60000/60000 [==============================] - 7s 122us/sample - loss: 42.4578 - vae_r_loss: 36.7430 - vae_kl_loss: 5.7148 - val_loss: 43.8409 - val_vae_r_loss: 38.0498 - val_vae_kl_loss: 5.7912
Epoch 58/100
60000/60000 [==============================] - 8s 129us/sample - loss: 42.4207 - vae_r_loss: 36.6978 - vae_kl_loss: 5.7229 - val_loss: 44.0166 - val_vae_r_loss: 38.3827 - val_vae_kl_loss: 5.6339
Epoch 59/100
60000/60000 [==============================] - 7s 122us/sample - loss: 42.3806 - vae_r_loss: 36.6549 - vae_kl_loss: 5.7257 - val_loss: 43.7744 - val_vae_r_loss: 37.9962 - val_vae_kl_loss: 5.7782
Epoch 60/100
60000/60000 [==============================] - 8s 126us/sample - loss: 42.3547 - vae_r_loss: 36.6247 - vae_kl_loss: 5.7300 - val_loss: 43.7177 - val_vae_r_loss: 38.0379 - val_vae_kl_loss: 5.6798
Epoch 61/100
60000/60000 [==============================] - 8s 129us/sample - loss: 42.3384 - vae_r_loss: 36.5944 - vae_kl_loss: 5.7440 - val_loss: 43.8315 - val_vae_r_loss: 38.1023 - val_vae_kl_loss: 5.7292
Epoch 62/100
60000/60000 [==============================] - 7s 124us/sample - loss: 42.3395 - vae_r_loss: 36.5870 - vae_kl_loss: 5.7525 - val_loss: 44.0155 - val_vae_r_loss: 38.5374 - val_vae_kl_loss: 5.4781
Epoch 63/100
60000/60000 [==============================] - 8s 125us/sample - loss: 42.3000 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7423 - val_loss: 43.8053 - val_vae_r_loss: 38.1347 - val_vae_kl_loss: 5.6706
Epoch 64/100
60000/60000 [==============================] - 7s 123us/sample - loss: 42.3011 - vae_r_loss: 36.5393 - vae_kl_loss: 5.7617 - val_loss: 44.0497 - val_vae_r_loss: 38.2698 - val_vae_kl_loss: 5.7799
Epoch 65/100
60000/60000 [==============================] - 8s 127us/sample - loss: 42.2317 - vae_r_loss: 36.4754 - vae_kl_loss: 5.7563 - val_loss: 43.8796 - val_vae_r_loss: 38.1148 - val_vae_kl_loss: 5.7648
Epoch 66/100
60000/60000 [==============================] - 8s 131us/sample - loss: 42.2031 - vae_r_loss: 36.4389 - vae_kl_loss: 5.7642 - val_loss: 44.0349 - val_vae_r_loss: 38.2374 - val_vae_kl_loss: 5.7974
Epoch 67/100
60000/60000 [==============================] - 8s 129us/sample - loss: 42.1963 - vae_r_loss: 36.4304 - vae_kl_loss: 5.7659 - val_loss: 44.2732 - val_vae_r_loss: 38.6935 - val_vae_kl_loss: 5.5797
Epoch 68/100
60000/60000 [==============================] - 8s 128us/sample - loss: 42.2105 - vae_r_loss: 36.4260 - vae_kl_loss: 5.7845 - val_loss: 43.9279 - val_vae_r_loss: 38.1766 - val_vae_kl_loss: 5.7512
Epoch 69/100
60000/60000 [==============================] - 7s 120us/sample - loss: 42.1248 - vae_r_loss: 36.3452 - vae_kl_loss: 5.7796 - val_loss: 43.7429 - val_vae_r_loss: 37.9045 - val_vae_kl_loss: 5.8385
Epoch 70/100
60000/60000 [==============================] - 8s 129us/sample - loss: 42.1134 - vae_r_loss: 36.3330 - vae_kl_loss: 5.7804 - val_loss: 43.9997 - val_vae_r_loss: 38.3955 - val_vae_kl_loss: 5.6042
Epoch 71/100
60000/60000 [==============================] - 8s 128us/sample - loss: 42.0834 - vae_r_loss: 36.2981 - vae_kl_loss: 5.7853 - val_loss: 43.9452 - val_vae_r_loss: 38.0412 - val_vae_kl_loss: 5.9041
Epoch 72/100
60000/60000 [==============================] - 7s 123us/sample - loss: 42.0814 - vae_r_loss: 36.2836 - vae_kl_loss: 5.7977 - val_loss: 43.9802 - val_vae_r_loss: 38.2024 - val_vae_kl_loss: 5.7778
Epoch 73/100
60000/60000 [==============================] - 8s 130us/sample - loss: 42.0609 - vae_r_loss: 36.2591 - vae_kl_loss: 5.8018 - val_loss: 43.9902 - val_vae_r_loss: 38.1622 - val_vae_kl_loss: 5.8281
Epoch 74/100
60000/60000 [==============================] - 8s 128us/sample - loss: 42.0404 - vae_r_loss: 36.2434 - vae_kl_loss: 5.7970 - val_loss: 43.8329 - val_vae_r_loss: 37.9803 - val_vae_kl_loss: 5.8526
Epoch 75/100
60000/60000 [==============================] - 8s 126us/sample - loss: 42.0296 - vae_r_loss: 36.2239 - vae_kl_loss: 5.8057 - val_loss: 43.8567 - val_vae_r_loss: 38.0927 - val_vae_kl_loss: 5.7640
Epoch 76/100
60000/60000 [==============================] - 8s 126us/sample - loss: 42.0071 - vae_r_loss: 36.2024 - vae_kl_loss: 5.8047 - val_loss: 43.9348 - val_vae_r_loss: 38.0641 - val_vae_kl_loss: 5.8707
Epoch 77/100
60000/60000 [==============================] - 8s 129us/sample - loss: 41.9923 - vae_r_loss: 36.1710 - vae_kl_loss: 5.8214 - val_loss: 43.7379 - val_vae_r_loss: 37.9295 - val_vae_kl_loss: 5.8084
Epoch 78/100
60000/60000 [==============================] - 8s 130us/sample - loss: 41.9557 - vae_r_loss: 36.1401 - vae_kl_loss: 5.8156 - val_loss: 43.9648 - val_vae_r_loss: 38.1244 - val_vae_kl_loss: 5.8404
Epoch 79/100
60000/60000 [==============================] - 8s 130us/sample - loss: 41.9122 - vae_r_loss: 36.0964 - vae_kl_loss: 5.8158 - val_loss: 43.9287 - val_vae_r_loss: 38.0709 - val_vae_kl_loss: 5.8578
Epoch 80/100
60000/60000 [==============================] - 8s 133us/sample - loss: 41.9308 - vae_r_loss: 36.0949 - vae_kl_loss: 5.8360 - val_loss: 43.7428 - val_vae_r_loss: 37.9610 - val_vae_kl_loss: 5.7818
Epoch 81/100
60000/60000 [==============================] - 8s 127us/sample - loss: 41.9012 - vae_r_loss: 36.0648 - vae_kl_loss: 5.8364 - val_loss: 44.3266 - val_vae_r_loss: 38.5170 - val_vae_kl_loss: 5.8096
Epoch 82/100
60000/60000 [==============================] - 7s 125us/sample - loss: 41.8750 - vae_r_loss: 36.0417 - vae_kl_loss: 5.8333 - val_loss: 43.7863 - val_vae_r_loss: 37.9801 - val_vae_kl_loss: 5.8062
Epoch 83/100
60000/60000 [==============================] - 7s 124us/sample - loss: 41.8738 - vae_r_loss: 36.0375 - vae_kl_loss: 5.8363 - val_loss: 43.7445 - val_vae_r_loss: 38.0025 - val_vae_kl_loss: 5.7420
Epoch 84/100
60000/60000 [==============================] - 7s 122us/sample - loss: 41.8827 - vae_r_loss: 36.0448 - vae_kl_loss: 5.8379 - val_loss: 43.7838 - val_vae_r_loss: 37.9531 - val_vae_kl_loss: 5.8307
Epoch 85/100
60000/60000 [==============================] - 8s 127us/sample - loss: 41.8511 - vae_r_loss: 36.0024 - vae_kl_loss: 5.8486 - val_loss: 43.8619 - val_vae_r_loss: 38.0344 - val_vae_kl_loss: 5.8275
Epoch 86/100
60000/60000 [==============================] - 8s 129us/sample - loss: 41.8079 - vae_r_loss: 35.9542 - vae_kl_loss: 5.8537 - val_loss: 43.7382 - val_vae_r_loss: 37.9344 - val_vae_kl_loss: 5.8038
Epoch 87/100
60000/60000 [==============================] - 7s 123us/sample - loss: 41.7851 - vae_r_loss: 35.9397 - vae_kl_loss: 5.8454 - val_loss: 44.1585 - val_vae_r_loss: 38.4855 - val_vae_kl_loss: 5.6730
Epoch 88/100
60000/60000 [==============================] - 7s 124us/sample - loss: 41.7767 - vae_r_loss: 35.9221 - vae_kl_loss: 5.8546 - val_loss: 43.9046 - val_vae_r_loss: 38.0805 - val_vae_kl_loss: 5.8241
Epoch 89/100
60000/60000 [==============================] - 8s 126us/sample - loss: 41.7489 - vae_r_loss: 35.8807 - vae_kl_loss: 5.8681 - val_loss: 43.8915 - val_vae_r_loss: 38.2089 - val_vae_kl_loss: 5.6826
Epoch 90/100
60000/60000 [==============================] - 8s 129us/sample - loss: 41.7509 - vae_r_loss: 35.8907 - vae_kl_loss: 5.8602 - val_loss: 44.1646 - val_vae_r_loss: 38.3839 - val_vae_kl_loss: 5.7807
Epoch 91/100
60000/60000 [==============================] - 8s 125us/sample - loss: 41.7379 - vae_r_loss: 35.8653 - vae_kl_loss: 5.8725 - val_loss: 43.8889 - val_vae_r_loss: 38.1448 - val_vae_kl_loss: 5.7441
Epoch 92/100
60000/60000 [==============================] - 8s 130us/sample - loss: 41.7136 - vae_r_loss: 35.8385 - vae_kl_loss: 5.8751 - val_loss: 43.8588 - val_vae_r_loss: 38.0302 - val_vae_kl_loss: 5.8286
Epoch 93/100
60000/60000 [==============================] - 7s 121us/sample - loss: 41.7043 - vae_r_loss: 35.8165 - vae_kl_loss: 5.8879 - val_loss: 43.8910 - val_vae_r_loss: 38.1920 - val_vae_kl_loss: 5.6990
Epoch 94/100
60000/60000 [==============================] - 8s 129us/sample - loss: 41.6996 - vae_r_loss: 35.8261 - vae_kl_loss: 5.8734 - val_loss: 43.7630 - val_vae_r_loss: 37.9765 - val_vae_kl_loss: 5.7865
Epoch 95/100
60000/60000 [==============================] - 8s 129us/sample - loss: 41.6560 - vae_r_loss: 35.7747 - vae_kl_loss: 5.8813 - val_loss: 44.0105 - val_vae_r_loss: 38.2086 - val_vae_kl_loss: 5.8019
Epoch 96/100
60000/60000 [==============================] - 8s 129us/sample - loss: 41.6414 - vae_r_loss: 35.7612 - vae_kl_loss: 5.8802 - val_loss: 44.0079 - val_vae_r_loss: 38.1036 - val_vae_kl_loss: 5.9043
Epoch 97/100
60000/60000 [==============================] - 7s 119us/sample - loss: 41.6342 - vae_r_loss: 35.7423 - vae_kl_loss: 5.8919 - val_loss: 43.8180 - val_vae_r_loss: 37.9197 - val_vae_kl_loss: 5.8984
Epoch 98/100
60000/60000 [==============================] - 8s 130us/sample - loss: 41.6329 - vae_r_loss: 35.7349 - vae_kl_loss: 5.8980 - val_loss: 43.7169 - val_vae_r_loss: 37.8978 - val_vae_kl_loss: 5.8190
Epoch 99/100
60000/60000 [==============================] - 7s 120us/sample - loss: 41.6222 - vae_r_loss: 35.7289 - vae_kl_loss: 5.8933 - val_loss: 43.8119 - val_vae_r_loss: 38.0215 - val_vae_kl_loss: 5.7904
Epoch 100/100
60000/60000 [==============================] - 8s 132us/sample - loss: 41.5922 - vae_r_loss: 35.7014 - vae_kl_loss: 5.8908 - val_loss: 43.6934 - val_vae_r_loss: 37.7484 - val_vae_kl_loss: 5.9449
Train duration : 767.34 sec. - 0:12:47
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -3,7 +3,7 @@
# _____ _ _ _
# | ___(_) __| | | ___
# | |_ | |/ _` | |/ _ \
# | _| | | (_| | | __/ Data_generator
# | _| | | (_| | | __/ DataGenerator
# |_|  |_|\__,_|_|\___|              for clustered CelebA dataset
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning (FIDLE)
......@@ -19,11 +19,11 @@ import os,glob
from tensorflow.keras.utils import Sequence
class Data_generator(Sequence):
class DataGenerator(Sequence):
version = 0.1
version = 0.4
def __init__(self, clusters_dir='./data', batch_size=32, debug=False):
def __init__(self, clusters_dir='./data', batch_size=32, debug=False, k_size=1):
'''
Instantiation of the data generator
args:
......@@ -31,6 +31,7 @@ class Data_generator(Sequence):
batch_size : Batch size (32)
debug : debug mode (False)
'''
if debug : self.about()
#
# ---- Get the list of clusters
#
......@@ -44,8 +45,13 @@ class Data_generator(Sequence):
for c in clusters_name:
df = pd.read_csv(c+'.csv', header=0)
dataset_size+=len(df.index)
#
# ---- If we only want to use a part of the dataset...
#
dataset_size = int(dataset_size * k_size)
#
if debug:
print(f'Clusters nb : {len(clusters_name)} files')
print(f'\nClusters nb : {len(clusters_name)} files')
print(f'Dataset size : {dataset_size}')
print(f'Batch size : {batch_size}')
......@@ -99,6 +105,11 @@ class Data_generator(Sequence):
return batch, batch
def on_epoch_end(self):
self.cluster_i = clusters_size
self.read_next_cluster()
def read_next_cluster(self):
#
# ---- Get the next cluster name
......@@ -123,10 +134,9 @@ class Data_generator(Sequence):
self.cluster_i = i
#
if self.debug: print(f'\n[Load {self.cluster_i:02d},s={len(self.data):3d}] ',end='')
@classmethod
def about(cls):
print('\nFIDLE 2020 - Data_generator')
print('\nFIDLE 2020 - DataGenerator')
print('Version :', cls.version)
\ No newline at end of file
......@@ -23,10 +23,9 @@ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
import tensorflow.keras.datasets.imdb as imdb
from modules.callbacks import ImagesCallback
from modules.data_generator import DataGenerator
import modules.callbacks
from modules.callbacks import ImagesCallback
import os, json, time, datetime
......@@ -163,20 +162,33 @@ class VariationalAutoencoder():
print(f'Optimizer is Adam with learning_rate={learning_rate:}')
def train(self,
x_train,x_test,
x_train=None,
x_test=None,
data_generator=None,
batch_size=32,
epochs=200,
image_periodicity=1,
chkpt_periodicity=2,
initial_epoch=0,
dataset_size=1
k_size=1
):
# ---- Dataset size
n_train = int(x_train.shape[0] * dataset_size)
n_test = int(x_test.shape[0] * dataset_size)
# ---- Data given or via generator
mode_data = (data_generator is None)
# ---- Size of the dataset we are going to use
# k_size ==1 : mean 100%
# ** Cannot be used with a data generator **
#
if mode_data:
n_train = int(x_train.shape[0] * k_size)
n_test = int(x_test.shape[0] * k_size)
else:
n_train = len(data_generator)
n_test = int(x_test.shape[0])
# ---- Need by callbacks
self.n_train = n_train
self.n_test = n_test
......@@ -201,20 +213,38 @@ class VariationalAutoencoder():
# ---- Let's go...
start_time = time.time()
self.history = self.model.fit(x_train[:n_train], x_train[:n_train],
batch_size = batch_size,
shuffle = True,
epochs = epochs,
initial_epoch = initial_epoch,
callbacks = callbacks_list,
validation_data = (x_test[:n_test], x_test[:n_test])
)
if mode_data:
#
# ---- With pure data (x_train) -----------------------------------------
#
self.history = self.model.fit(x_train[:n_train], x_train[:n_train],
batch_size = batch_size,
shuffle = True,
epochs = epochs,
initial_epoch = initial_epoch,
callbacks = callbacks_list,
validation_data = (x_test[:n_test], x_test[:n_test])
)
#
else:
# ---- With Data Generator ----------------------------------------------
#
self.history = self.model.fit(data_generator,
shuffle = True,
epochs = epochs,
initial_epoch = initial_epoch,
callbacks = callbacks_list
# validation_data = (x_test, x_test)
)
end_time = time.time()
dt = end_time-start_time
dth = str(datetime.timedelta(seconds=int(dt)))
self.duration = dt
print(f'\nTrain duration : {dt:.2f} sec. - {dth:}')
def plot_model(self):
d=self.run_directory+'/figs'
plot_model(self.model, to_file=f'{d}/model.png', show_shapes = True, show_layer_names = True, expand_nested=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment