Commit e0c5d593 authored by Jean-Luc Parouty

Update VAE with celebA dataset

parent f28f4f37
%% Cell type:markdown id: tags:
<img width="800px" src="../fidle/img/00-Fidle-header-01.svg"></img>
# <!-- TITLE --> [VAE8] - Variational AutoEncoder (VAE) with CelebA (small)
<!-- DESC --> Variational AutoEncoder (VAE) with CelebA (small res. 128x128)
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->
## Objectives :
- Build and train a VAE model on a large dataset at **small resolution (>70 GB)**
- Understand a more advanced programming model based on a **data generator**

The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 images: as a NumPy array, its shape is (202599, 218, 178, 3).
## What we're going to do :
- Define a VAE model
- Build the model
- Train it
- Follow the learning process with TensorBoard
%% Cell type:markdown id: tags:
## Step 1 - Setup environment
### 1.1 - Python stuff
%% Cell type:code id: tags:
``` python
import tensorflow as tf
import numpy as np
import os, sys
from importlib import reload

# ---- Local modules, reloaded to pick up changes during development
import modules.vae
import modules.data_generator
reload(modules.data_generator)
reload(modules.vae)
from modules.vae import VariationalAutoencoder
from modules.data_generator import DataGenerator

# ---- Fidle practical-work environment
sys.path.append('..')
import fidle.pwk_ns as ooo

place, datasets_dir = ooo.init()
VariationalAutoencoder.about()
DataGenerator.about()
```
%% Output
FIDLE 2020 - Practical Work Module
Version : 0.57 DEV
Run time : Sunday 13 September 2020, 10:15:25
TensorFlow version : 2.2.0
Keras version : 2.3.0-tf
Current place : Fidle at IDRIS
Datasets dir : /gpfswork/rech/mlh/commun/datasets
Update keras cache : Done
FIDLE 2020 - Variational AutoEncoder (VAE)
TensorFlow version : 2.2.0
VAE version : 1.28
FIDLE 2020 - DataGenerator
Version : 0.4.1
%% Cell type:markdown id: tags:
### 1.2 - The good place
%% Cell type:code id: tags:
``` python
train_dir = f'{datasets_dir}/celeba/clusters-s.train'
test_dir  = f'{datasets_dir}/celeba/clusters-s.test'
```
%% Cell type:markdown id: tags:
## Step 2 - DataGenerator and validation data
Everything is in place, so let's instantiate our generator for the entire training set, and load the test images directly.
%% Cell type:code id: tags:
``` python
# ---- Generator on the train clusters; the test set fits in memory
data_gen = DataGenerator(train_dir, 32, k_size=1)
x_test   = np.load(f'{test_dir}/images-000.npy')

print(f'Data generator : {len(data_gen)} batches of {data_gen.batch_size} images, or {data_gen.dataset_size} images')
print(f'x_test : {len(x_test)} images')
```
%% Output
Data generator : 6250 batches of 32 images, or 200000 images
x_test : 2599 images
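%% Cell type:markdown id: tags:
The dataset is far too large to fit in memory, so the images are stored as a series of `.npy` cluster files that the generator loads one at a time. Below is a minimal, illustrative sketch of the `keras.utils.Sequence` pattern that `modules/data_generator.py` follows; the class name and the single-cluster simplification are ours, not the module's actual code.
``` python
import math
import numpy as np
from tensorflow.keras.utils import Sequence

class ClusterSequence(Sequence):
    '''Simplified sketch : serve batches from one pre-loaded .npy cluster.
    The real DataGenerator cycles through many clusters per epoch.'''

    def __init__(self, cluster_dir, batch_size):
        self.batch_size = batch_size
        # Hypothetical : load a single cluster file, as x_test does above
        self.data = np.load(f'{cluster_dir}/images-000.npy')

    def __len__(self):
        # Number of complete batches per epoch
        return math.floor(len(self.data) / self.batch_size)

    def __getitem__(self, i):
        batch = self.data[i*self.batch_size : (i+1)*self.batch_size]
        # Returning the image as both input and target is one common
        # autoencoder convention; the real module may differ.
        return batch, batch
```
With `dataset_size = 200000` and `batch_size = 32`, `__len__` gives `floor(200000/32) = 6250` batches, which matches the output above.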
%% Cell type:markdown id: tags:
## Step 3 - Get VAE model
%% Cell type:code id: tags:
``` python
tag         = f'CelebA.001-S.{os.getenv("SLURM_JOB_ID","unknown")}'
input_shape = (128, 128, 3)
z_dim       = 200
verbose     = 1

# ---- Encoder : 4 stride-2 convolutions, 128x128 -> 8x8
encoder = [ {'type':'Conv2D',          'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25},
            {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25},
            {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25},
            {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25}
          ]

# ---- Decoder : mirror of the encoder, 8x8 -> 128x128, sigmoid output in [0,1]
decoder = [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25},
            {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25},
            {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
            {'type':'Dropout',         'rate':0.25},
            {'type':'Conv2DTranspose', 'filters':3,  'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'sigmoid'}
          ]

vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape,
                                         encoder_layers = encoder,
                                         decoder_layers = decoder,
                                         z_dim          = z_dim,
                                         verbose        = verbose,
                                         run_tag        = tag)

# model=None : save the configuration only (see 'Config saved' below)
vae.save(model=None)
```
%% Output
Model initialized.
Outputs will be in : ./run/CelebA.001-S.265973
---------- Encoder --------------------------------------------------
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) [(None, 128, 128, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 64, 64, 32) 896 encoder_input[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 64, 64, 32) 0 conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 32, 32, 64) 18496 dropout[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 32, 32, 64) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 16, 16, 64) 36928 dropout_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 16, 16, 64) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 8, 8, 64) 36928 dropout_2[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 8, 8, 64) 0 conv2d_3[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 4096) 0 dropout_3[0][0]
__________________________________________________________________________________________________
mu (Dense) (None, 200) 819400 flatten[0][0]
__________________________________________________________________________________________________
log_var (Dense) (None, 200) 819400 flatten[0][0]
__________________________________________________________________________________________________
encoder_output (Lambda) (None, 200) 0 mu[0][0]
log_var[0][0]
==================================================================================================
Total params: 1,732,048
Trainable params: 1,732,048
Non-trainable params: 0
__________________________________________________________________________________________________
---------- Decoder --------------------------------------------------
Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) [(None, 200)] 0
_________________________________________________________________
dense (Dense) (None, 4096) 823296
_________________________________________________________________
reshape (Reshape) (None, 8, 8, 64) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 16, 16, 64) 36928
_________________________________________________________________
dropout_4 (Dropout) (None, 16, 16, 64) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 32, 32, 64) 36928
_________________________________________________________________
dropout_5 (Dropout) (None, 32, 32, 64) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 64, 64, 32) 18464
_________________________________________________________________
dropout_6 (Dropout) (None, 64, 64, 32) 0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 128, 128, 3) 867
=================================================================
Total params: 916,483
Trainable params: 916,483
Non-trainable params: 0
_________________________________________________________________
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
Config saved in : ./run/CelebA.001-S.265973/models/vae_config.json
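%% Cell type:markdown id: tags:
In the encoder summary above, `mu` and `log_var` feed a `Lambda` layer (`encoder_output`) that samples the latent vector z. A minimal sketch of the classic reparameterization trick such a layer typically implements (illustrative; the module's actual code may differ):
``` python
import tensorflow as tf

def sampling(args):
    '''Reparameterization trick : z = mu + sigma * epsilon, epsilon ~ N(0, I).
    Sampling this way keeps z differentiable w.r.t. mu and log_var.'''
    mu, log_var = args
    epsilon = tf.random.normal(shape=tf.shape(mu))
    return mu + tf.exp(0.5 * log_var) * epsilon
```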
%% Cell type:markdown id: tags:
## Step 4 - Compile it
%% Cell type:code id: tags:
``` python
optimizer = tf.keras.optimizers.Adam(1e-4)
# optimizer = 'adam'

# r_loss_factor weights the reconstruction loss against the KL divergence
r_loss_factor = 10000
vae.compile(optimizer, r_loss_factor)
```
%% Output
Compiled.
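%% Cell type:markdown id: tags:
`r_loss_factor` weights the reconstruction term against the KL regularization term. Assuming the module combines them as a weighted sum (the KL term matches the `train_step` shown in the diff further below), the loss per batch is:

$$\mathcal{L} \;=\; r_{loss}\cdot \mathcal{L}_{recon} \;+\; \mathcal{L}_{KL}, \qquad \mathcal{L}_{KL} \;=\; -\tfrac{1}{2}\,\mathbb{E}\big[\, 1 + \log\sigma^2 - \mu^2 - \sigma^2 \,\big]$$

A large factor such as 10000 tends to favor sharp reconstructions over a smooth latent space.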
%% Cell type:markdown id: tags:
## Step 5 - Train
For 10 epochs with the Adam optimizer:
- Run time at IDRIS : 1299.77 sec. - 0:21:39
- Run time at GRICAD : 2092.77 sec. - 0:34:52
%% Cell type:code id: tags:
``` python
epochs = 10
initial_epoch = 0
```
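%% Cell type:markdown id: tags:
With 6250 batches per epoch (see Step 2), these 10 epochs represent 10 × 6250 = 62,500 optimizer steps on batches of 32 images. `initial_epoch` allows a training run to resume where a previous one stopped.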
%% Cell type:code id: tags:
``` python
vae.train(data_generator = data_gen,
          x_test         = x_test,
          epochs         = epochs,
          initial_epoch  = initial_epoch)
```
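%% Cell type:markdown id: tags:
To follow the learning process, point TensorBoard at the run directory created in Step 3, e.g. `tensorboard --logdir ./run/CelebA.001-S.265973` (the exact path depends on your SLURM job id, and this assumes the module writes its logs under the run directory).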
%% Cell type:markdown id: tags:
---
<img width="80px" src="../fidle/img/00-Fidle-logo-01.svg"></img>
@@ -33,14 +33,21 @@ class Sampling(keras.layers.Layer):

 class VAE(keras.Model):
     '''A VAE model, built from given encoder, decoder'''

-    def __init__(self, encoder=None, decoder=None, r_loss_factor=0.3, **kwargs):
+    def __init__(self, encoder=None, decoder=None, r_loss_factor=0.3, image_size=(28,28), **kwargs):
         super(VAE, self).__init__(**kwargs)
         self.encoder       = encoder
         self.decoder       = decoder
         self.r_loss_factor = r_loss_factor
-        print('Init VAE, with r_loss_factor=', self.r_loss_factor)
+        self.nb_pixels     = np.prod(image_size)
+        print(f'Init VAE, with r_loss_factor={self.r_loss_factor} and image_size={image_size}')

     def call(self, inputs):
         z      = self.encoder(inputs)
         y_pred = self.decoder(z)
         return y_pred

     def train_step(self, data):
         if isinstance(data, tuple):
@@ -50,7 +57,7 @@ class VAE(keras.Model):
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction       = self.decoder(z)
             reconstruction_loss  = tf.reduce_mean( keras.losses.binary_crossentropy(data, reconstruction) )
-            reconstruction_loss *= 28*28
+            reconstruction_loss *= self.nb_pixels
             kl_loss  = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
             kl_loss  = tf.reduce_mean(kl_loss)
             kl_loss *= -0.5
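Scaling the pixel-wise binary cross-entropy by `nb_pixels` generalizes the hard-coded `28*28 = 784` (MNIST) to any image size; with `image_size=(128,128)`, the factor becomes `128*128 = 16384`.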
@@ -67,9 +67,13 @@ class DataGenerator(Sequence):
         #
         # ---- Read a first cluster
         #
-        self.cluster_i = clusters_size
-        self.read_next_cluster()
+        self.rewind()
+
+    def rewind(self):
+        self.cluster_i = self.clusters_size
+        self.read_next_cluster()

     def __len__(self):
         return math.floor(self.dataset_size / self.batch_size)
@@ -106,8 +110,7 @@ class DataGenerator(Sequence):

     def on_epoch_end(self):
-        self.cluster_i = self.clusters_size
-        self.read_next_cluster()
+        self.rewind()

     def read_next_cluster(self):
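Factoring the reset logic into `rewind()` lets `__init__` and `on_epoch_end` share one code path, so the generator is reset the same way at start-up and at the end of each epoch.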
@@ -66,3 +66,8 @@ VAE7_smart_image_size = (128,128)
 VAE7_smart_enhanced_dir = './data'
 VAE7_full_image_size = (192,160)
 VAE7_full_enhanced_dir = '{datasets_dir}/celeba/enhanced'
+
+VAE8_smart_image_size = (128,128)
+VAE8_smart_enhanced_dir = './data'
+VAE8_full_image_size = (192,160)
+VAE8_full_enhanced_dir = '{datasets_dir}/celeba/enhanced'
@@ -78,10 +78,10 @@
         "duration": "00:00:03 329ms"
     },
     "VAE1": {
-        "path": "/gpfsdswork/projects/rech/mlh/uja62cb/fidle/VAE",
-        "start": "Tuesday 29 December 2020, 18:03:41",
-        "end": "Tuesday 29 December 2020, 18:07:02",
-        "duration": "00:03:22 697ms"
+        "path": "/home/pjluc/dev/fidle/VAE",
+        "start": "Monday 4 January 2021, 17:59:08",
+        "end": "",
+        "duration": "Unfinished..."
     },
     "MNIST1": {
         "path": "/home/pjluc/dev/fidle/MNIST",
@@ -148,5 +148,11 @@
         "start": "Saturday 2 January 2021, 17:28:39",
         "end": "Saturday 2 January 2021, 17:28:47",
         "duration": "00:00:08 736ms"
-    }
+    },
+    "VAE8": {
+        "path": "/home/pjluc/dev/fidle/VAE",
+        "start": "Monday 4 January 2021, 18:43:15",
+        "end": "Monday 4 January 2021, 18:56:01",
+        "duration": "00:12:46 153ms"
+    }
 }
 \ No newline at end of file