%% Cell type:markdown id: tags:
<img width="800px" src="../fidle/img/00-Fidle-header-01.svg"></img>
# <!-- TITLE --> [VAE8] - Variational AutoEncoder (VAE) with CelebA (small)
<!-- DESC --> Variational AutoEncoder (VAE) with CelebA (small res. 128x128)
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->
## Objectives :
- Build and train a VAE model on a large dataset (>70 GB) in **small resolution**
- Understand a more advanced programming model based on a **data generator**
The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 images (dataset shape: (202599, 218, 178, 3)).
## What we're going to do :
- Define a VAE model
- Build the model
- Train it
- Follow the learning process with TensorBoard
## Acknowledgements :
As before, thanks to **François Chollet**, whose work is the basis of this example.
See : https://keras.io/examples/generative/vae
%% Cell type:markdown id: tags:
## Step 1 - Init Python stuff
%% Cell type:code id: tags:
``` python
import numpy as np
from skimage import io

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard

import os, sys, json, time, datetime
from IPython.display import display, Image, Markdown, HTML

# ---- Project modules : data generator, VAE model, callbacks
from modules.data_generator import DataGenerator
from modules.VAE import VAE, Sampling
from modules.callbacks import ImagesCallback, BestModelCallback

# ---- Fidle practical-work module (located in the parent directory)
sys.path.append('..')
import fidle.pwk as pwk

run_dir = './run/CelebA.001'              # Output directory
datasets_dir = pwk.init('VAE8', run_dir)

VAE.about()
DataGenerator.about()
```
%% Output
Override : Attribute [run_dir=./run/CelebA.001] with [./run/test-VAE8-3370]
**FIDLE 2020 - Practical Work Module**
Version : 0.6.1 DEV
Notebook id : VAE8
Run time : Wednesday 6 January 2021, 19:47:34
TensorFlow version : 2.2.0
Keras version : 2.3.0-tf
Datasets dir : /home/pjluc/datasets/fidle
Run dir : ./run/test-VAE8-3370
Update keras cache : False
Save figs : True
Path figs : ./run/test-VAE8-3370/figs
<br>**FIDLE 2021 - VAE**
Version : 1.2
TensorFlow version : 2.2.0
Keras version : 2.3.0-tf
<br>**FIDLE 2020 - DataGenerator**
Version : 0.4.1
TensorFlow version : 2.2.0
Keras version : 2.3.0-tf
%% Cell type:code id: tags:
``` python
# To clean run_dir, uncomment and run the next line
# ! rm -r "$run_dir"/images-* "$run_dir"/logs "$run_dir"/figs "$run_dir"/models ; rmdir "$run_dir"
```
%% Cell type:markdown id: tags:
## Step 2 - Get some data
Let's instantiate our generator for the entire dataset.
%% Cell type:markdown id: tags:
### 2.1 - Parameters
Uncomment the right lines according to the data you want to use
%% Cell type:code id: tags:
``` python
# ---- For tests
scale = 0.3
image_size = (128,128)
enhanced_dir = './data'
latent_dim = 300
r_loss_factor = 0.6
batch_size = 64
epochs = 15
# ---- Training with a full dataset
# scale = 1.
# image_size = (128,128)
# enhanced_dir = f'{datasets_dir}/celeba/enhanced'
# latent_dim = 300
# r_loss_factor = 0.6
# batch_size = 64
# epochs = 15
# ---- Training with a full dataset of large images
# scale = 1.
# image_size = (192,160)
# enhanced_dir = f'{datasets_dir}/celeba/enhanced'
# latent_dim = 300
# r_loss_factor = 0.6
# batch_size = 64
# epochs = 15
```
%% Cell type:markdown id: tags:
### 2.2 - Finding the right place
%% Cell type:code id: tags:
``` python
# ---- Override parameters (batch mode) - Just forget these lines
#
pwk.override('scale', 'image_size', 'enhanced_dir', 'latent_dim', 'r_loss_factor', 'batch_size', 'epochs')
# ---- Where the cluster files are located
#
lx,ly = image_size
train_dir = f'{enhanced_dir}/clusters-{lx}x{ly}'
print('Train directory is :',train_dir)
```
%% Output
Train directory is : ./data/clusters-128x128
%% Cell type:markdown id: tags:
### 2.3 - Get a DataGenerator
%% Cell type:code id: tags:
``` python
data_gen = DataGenerator(train_dir, 32, k_size=scale)
print(f'Data generator is ready with : {len(data_gen)} batches of {data_gen.batch_size} images, or {data_gen.dataset_size} images')
```
%% Output
Data generator is ready with : 379 batches of 32 images, or 12155 images
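%% Cell type:markdown id: tags:
The `DataGenerator` feeds the model batch by batch, so the full dataset (>70 GB at full scale) never has to fit in memory. As an illustration of this programming model, such a generator can be built on `keras.utils.Sequence`; the sketch below is only a minimal example (the real class in `modules/data_generator.py` reads its batches from the cluster files and is more complete).
%% Cell type:code id: tags:
``` python
# ---- Illustration only : a minimal batch generator based on keras.utils.Sequence
#      (assumption : the real DataGenerator follows this pattern, but reads its
#       batches from the cluster files instead of an in-memory array)
#
class MinimalDataGenerator(keras.utils.Sequence):
    def __init__(self, x, batch_size=32):
        self.x          = x
        self.batch_size = batch_size
    def __len__(self):
        # Number of batches per epoch
        return len(self.x) // self.batch_size
    def __getitem__(self, idx):
        # Return one batch ; for an autoencoder, inputs are also the targets
        batch = self.x[idx*self.batch_size : (idx+1)*self.batch_size]
        return batch, batch
```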
%% Cell type:markdown id: tags:
## Step 3 - Build model
%% Cell type:markdown id: tags:
#### Encoder
%% Cell type:code id: tags:
``` python
inputs = keras.Input(shape=(lx, ly, 3))
x = layers.Conv2D(32, 3, strides=2, padding="same", activation="relu")(inputs)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
shape_before_flattening = keras.backend.int_shape(x)[1:]
x = layers.Flatten()(x)
x = layers.Dense(512, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
encoder.compile()
# encoder.summary()
```
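%% Cell type:markdown id: tags:
The `Sampling` layer imported from `modules.VAE` implements the reparameterization trick: `z` is drawn from `N(z_mean, exp(z_log_var))` in a way that keeps the graph differentiable. A minimal illustrative version, in the spirit of the Keras VAE example cited above, could look like this (the actual class may differ in detail):
%% Cell type:code id: tags:
``` python
# ---- Illustration only : a minimal reparameterization (sampling) layer
#      (assumption : the Sampling class in modules.VAE works like the Keras VAE example)
#
class SamplingSketch(keras.layers.Layer):
    """Draws z = z_mean + exp(0.5*z_log_var) * epsilon, with epsilon ~ N(0,1)."""
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
```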
%% Cell type:markdown id: tags:
#### Decoder
%% Cell type:code id: tags:
``` python
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(np.prod(shape_before_flattening))(inputs)
x = layers.Reshape(shape_before_flattening)(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
outputs = layers.Conv2DTranspose(3, 3, padding="same", activation="sigmoid")(x)
decoder = keras.Model(inputs, outputs, name="decoder")
decoder.compile()
# decoder.summary()
```
%% Cell type:markdown id: tags:
#### VAE
Our loss function is the weighted sum of two terms:
- `reconstruction_loss`, which measures the reconstruction error,
- `kl_loss`, which measures the dispersion of the latent distribution.
The weighting is defined by `r_loss_factor`:
`total_loss = r_loss_factor*reconstruction_loss + (1-r_loss_factor)*kl_loss`
If `r_loss_factor = 1`, the loss includes only `reconstruction_loss`.
If `r_loss_factor = 0`, the loss includes only `kl_loss`.
In practice, a value around 0.5 gives good results here.
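%% Cell type:markdown id: tags:
For reference, this weighted loss is typically computed inside a custom `train_step`, as in the Keras VAE example. The sketch below shows the general pattern only; the actual `VAE` class imported from `modules.VAE` may differ in its details (metric handling, loss scaling, ...).
%% Cell type:code id: tags:
``` python
# ---- Illustration only : general pattern of a custom train_step computing the
#      weighted VAE loss (assumption : the VAE class in modules.VAE follows the
#      Keras VAE example, with r_loss_factor added)
#
class VAESketch(keras.Model):
    def __init__(self, encoder, decoder, r_loss_factor, **kwargs):
        super().__init__(**kwargs)
        self.encoder       = encoder
        self.decoder       = decoder
        self.r_loss_factor = r_loss_factor
    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction       = self.decoder(z)
            # Reconstruction loss : mean pixel-wise binary cross-entropy
            r_loss  = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
            # KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1)
            kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            # Weighted sum, as described above
            loss    = self.r_loss_factor * r_loss + (1 - self.r_loss_factor) * kl_loss
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {'loss': loss, 'r_loss': r_loss, 'kl_loss': kl_loss}
```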
%% Cell type:code id: tags:
``` python
vae = VAE(encoder, decoder, r_loss_factor)
vae.compile(optimizer=keras.optimizers.Adam())
```
%% Cell type:markdown id: tags:
## Step 4 - Train
Indicative training times: about 20 minutes on a CPU, about 1'12 on a GPU (V100, IDRIS).
%% Cell type:markdown id: tags:
### 4.1 - Callbacks
%% Cell type:code id: tags:
``` python
x_draw,_ = data_gen[0]
data_gen.rewind()
# ---- Callback : Images encoded
pwk.mkdir(run_dir + '/images-encoded')
filename = run_dir + '/images-encoded/image-{epoch:03d}-{i:02d}.jpg'
callback_images1 = ImagesCallback(filename, x=x_draw[:5], encoder=encoder,decoder=decoder)
# ---- Callback : Images generated
pwk.mkdir(run_dir + '/images-generated')
filename = run_dir + '/images-generated/image-{epoch:03d}-{i:02d}.jpg'
callback_images2 = ImagesCallback(filename, x=None, nb_images=5, z_dim=latent_dim, encoder=encoder,decoder=decoder)
# ---- Callback : Best model
pwk.mkdir(run_dir + '/models')
filename = run_dir + '/models/best_model'
callback_bestmodel = BestModelCallback(filename)
# ---- Callback tensorboard
dirname = run_dir + '/logs'
callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)
callbacks_list = [callback_images1, callback_images2, callback_bestmodel, callback_tensorboard]
# callbacks_list = [callback_images1, callback_images2, callback_bestmodel]      # Variant without TensorBoard
```
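%% Cell type:markdown id: tags:
`ImagesCallback` and `BestModelCallback` come from `modules/callbacks.py`. To illustrate the mechanism, a minimal callback that saves a few reconstructed images at the end of each epoch could look like the sketch below (the real classes are more complete, e.g. they also handle generation from random latent vectors):
%% Cell type:code id: tags:
``` python
# ---- Illustration only : a minimal Keras callback saving reconstructed images
#      at the end of each epoch (assumption : ImagesCallback follows this kind of pattern)
#
class MinimalImagesCallback(keras.callbacks.Callback):
    def __init__(self, filename, x, encoder, decoder):
        self.filename = filename        # e.g. 'image-{epoch:03d}-{i:02d}.jpg'
        self.x        = x
        self.encoder  = encoder
        self.decoder  = decoder
    def on_epoch_end(self, epoch, logs=None):
        # Encode then decode the reference images with the current weights
        z_mean, z_log_var, z = self.encoder.predict(self.x)
        reconstructed        = self.decoder.predict(z)
        for i, img in enumerate(reconstructed):
            io.imsave(self.filename.format(epoch=epoch, i=i), (img * 255).astype(np.uint8))
```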
%% Cell type:markdown id: tags:
### 4.2 - Train it
%% Cell type:code id: tags:
``` python
# ---- batch_size and epochs are defined in the parameters cell (and may be overridden in batch mode)
pwk.chrono_start()
history = vae.fit(data_gen, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list)
pwk.chrono_show()
```
%% Cell type:markdown id: tags:
## Step 5 - About our training session
### 5.1 - History
%% Cell type:code id: tags:
``` python
pwk.plot_history(history, plot={"Loss":['loss','r_loss', 'kl_loss']}, save_as='01-history')
```
%% Cell type:markdown id: tags:
### 5.2 - Reconstruction (input -> encoder -> decoder)
%% Cell type:code id: tags:
``` python
imgs = []
for epoch in range(1, epochs):
    for i in range(5):
        filename = f'{run_dir}/images-encoded/image-{epoch:03d}-{i:02d}.jpg'
        img      = io.imread(filename)
        imgs.append(img)
pwk.subtitle('Original images :')
pwk.plot_images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')
pwk.subtitle('Encoded/decoded images :')
pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')
```
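%% Cell type:markdown id: tags:
The images above were saved by `ImagesCallback` during training. With the trained model at hand, the same reconstruction can also be done directly (here using the sampled `z`; `z_mean` could be used as well):
%% Cell type:code id: tags:
``` python
# ---- Direct reconstruction with the trained model : x -> encoder -> z -> decoder
#
z_mean, z_log_var, z = encoder.predict(x_draw[:5])
x_reconstructed      = decoder.predict(z)
pwk.plot_images(x_reconstructed, None, indices='all', columns=5, x_size=2, y_size=2, save_as=None)
```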
%% Cell type:markdown id: tags:
### 5.3 - Generation (latent -> decoder)
%% Cell type:code id: tags:
``` python
imgs = []
for epoch in range(1, epochs):
    for i in range(5):
        filename = f'{run_dir}/images-generated/image-{epoch:03d}-{i:02d}.jpg'
        img      = io.imread(filename)
        imgs.append(img)
pwk.subtitle('Generated images from latent space :')
pwk.plot_images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-generated')
```
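%% Cell type:markdown id: tags:
New images can also be generated on the fly, by sampling latent vectors from a normal distribution and decoding them:
%% Cell type:code id: tags:
``` python
# ---- Generate new images directly : z ~ N(0,1) -> decoder
#
z_new     = np.random.normal(size=(5, latent_dim))
generated = decoder.predict(z_new)
pwk.plot_images(generated, None, indices='all', columns=5, x_size=2, y_size=2, save_as=None)
```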
%% Cell type:code id: tags:
``` python
pwk.end()
```
%% Cell type:markdown id: tags:
---
<img width="80px" src="../fidle/img/00-Fidle-logo-01.svg"></img>