<img width="800px" src="../fidle/img/00-Fidle-header-01.svg"></img>

# <!-- TITLE --> [VAE9] - Training session for our VAE with 192x160 images
<!-- DESC --> Episode 4 : Training with our clustered datasets in notebook or batch mode
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->

## Objectives :
 - Build and train a VAE model with a large dataset in **medium resolution (140 GB)**
 - Understand a more advanced programming model using a **data generator**

The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 images (202599,218,178,3).  
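The "140 GB" quoted in the objectives can be checked with a little arithmetic. This sketch assumes that the 192x160 RGB images are stored as float64 arrays (8 bytes per value). That is an assumption: the storage type is not stated here. With float32, the size would be halved.

```python
# Back-of-the-envelope size of CelebA once resized to 192x160 RGB.
# Assumption: values are stored as float64 (8 bytes each).
n_images = 202599
lx, ly, channels = 192, 160, 3
bytes_per_value = 8  # float64 (assumed)

size_bytes = n_images * lx * ly * channels * bytes_per_value
size_gib = size_bytes / 2**30

print(f'{size_bytes:,} bytes, i.e. about {size_gib:.0f} GiB')
# → 149,372,190,720 bytes, i.e. about 139 GiB
```

That is far more than the RAM of a typical machine, which is why the images are read through a data generator instead of being loaded all at once.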

## What we're going to do :

 - Define a VAE model
 - Build the model
 - Train it
 - Follow the learning process with Tensorboard

## Acknowledgements :
As before, thanks to **François Chollet** who is at the base of this example.  
See : https://keras.io/examples/generative/vae


## Step 1 - Init python stuff

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import TensorBoard

from modules.models    import VAE
from modules.layers    import SamplingLayer
from modules.callbacks import ImagesCallback, BestModelCallback
from modules.datagen   import DataGenerator

import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('VAE9')

VAE.about()
DataGenerator.about()

In [None]:
# To clean run_dir, uncomment and run this next line
# ! rm -r "$run_dir"/images-* "$run_dir"/logs "$run_dir"/figs "$run_dir"/models ; rmdir "$run_dir"

## Step 2 - Parameters
`scale` : Fraction of the dataset used (1 = full dataset). With scale=1, we need 1'30s on a GPU V100 ...and >20' on a CPU !  
`latent_dim` : Size of the latent space (300 here). 2 dimensions would be small, but useful for drawing !  
`fit_verbosity`: Verbosity of training progress bar: 0=silent, 1=progress bar, 2=One line  

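`loss_weights` : Our **loss function** is the weighted sum of two losses:
 - `r_loss` which measures the reconstruction error (how far the decoded images are from the input images).  
 - `kl_loss` which measures the dispersion of the latent space, as the Kullback-Leibler divergence between the encoded distribution and a standard normal distribution.  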

The weights are defined by: `loss_weights=[k1,k2]` where : `total_loss = k1*r_loss + k2*kl_loss`  
In practice, a value of \[.6,.4\] gives good results here.
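To make the weighted sum concrete, here is a NumPy sketch of the two terms. The `VAE` class of `modules.models` is not shown here, so this assumes a common formulation: binary cross-entropy for `r_loss` and the closed-form Gaussian KL divergence for `kl_loss`. The real implementation may scale or average these terms differently. The name `vae_losses` is hypothetical.

```python
import numpy as np

def vae_losses(x, x_hat, z_mean, z_log_var, loss_weights=(.6, .4)):
    """Weighted VAE loss, sketched with NumPy (assumed formulation)."""
    eps = 1e-7
    x_hat = np.clip(x_hat, eps, 1 - eps)                 # avoid log(0)

    # r_loss : binary cross-entropy summed over pixels and channels, averaged over the batch
    bce = -(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))
    r_loss = np.mean(np.sum(bce, axis=(1, 2, 3)))

    # kl_loss : closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
    # summed over latent dimensions, averaged over the batch
    kl = -0.5 * (1 + z_log_var - np.square(z_mean) - np.exp(z_log_var))
    kl_loss = np.mean(np.sum(kl, axis=1))

    k1, k2 = loss_weights
    return k1 * r_loss + k2 * kl_loss, r_loss, kl_loss

# Usage on random data: a batch of 4 images of 192x160x3, latent_dim=300
rng = np.random.default_rng(0)
x = rng.random((4, 192, 160, 3))
x_hat = rng.random((4, 192, 160, 3))
z_mean = np.zeros((4, 300))
z_log_var = np.zeros((4, 300))          # encoded distribution equal to N(0, I)

total_loss, r_loss, kl_loss = vae_losses(x, x_hat, z_mean, z_log_var)
print(total_loss, r_loss, kl_loss)      # kl_loss is 0 here
```

When the encoded distribution matches the standard normal prior, `kl_loss` is zero, and `total_loss` reduces to `k1*r_loss`.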


Uncomment the right lines according to what you want.

In [1]:
fit_verbosity = 1

# ---- For tests

scale         = 0.01
image_size    = (192,160)
enhanced_dir  = './data'
latent_dim    = 300
loss_weights  = [.6,.4]
batch_size    = 64
epochs        = 5

# ---- Training with a full dataset of large images
#
# scale         = 1.
# image_size    = (192,160)
# enhanced_dir  = f'{datasets_dir}/celeba/enhanced'
# latent_dim    = 300
# loss_weights  = [.6,.4]
# batch_size    = 64
# epochs        = 15

Override parameters (batch mode) - Just forget this cell

In [None]:
fidle.override('scale', 'image_size', 'enhanced_dir', 'latent_dim', 'loss_weights')
fidle.override('batch_size', 'epochs', 'fit_verbosity')

## Step 3 - Prepare data
Let's instantiate our generator for the entire dataset.

### 3.1 - Finding the right place

In [None]:
lx,ly      = image_size
train_dir  = f'{enhanced_dir}/clusters-{lx}x{ly}'

print('Train directory is :',train_dir)

### 3.2 - Get a DataGenerator

In [None]:
data_gen = DataGenerator(train_dir, batch_size, scale=scale)    # the generator defines the batch size used for training

print(f'Data generator is ready with : {len(data_gen)} batches of {data_gen.batch_size} images, or {data_gen.dataset_size} images')
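The idea behind a data generator is to serve the dataset batch by batch, so that it never has to fit in memory. Here is a hypothetical sketch of such a generator, not the code of `modules.datagen.DataGenerator`. It assumes that the dataset is split into `.npy` "cluster" files, and follows the `__len__` / `__getitem__` protocol used by `keras.utils.Sequence`. The class name `ClusterSequence` and the file layout are illustrative.

```python
import os
import tempfile

import numpy as np


class ClusterSequence:
    """Hypothetical sketch of a cluster-based data generator (not modules.datagen).

    Assumes the dataset is split into .npy "cluster" files of shape (n, lx, ly, 3).
    Batches are cut inside each cluster and only one cluster is held in memory
    at a time, so the whole dataset never has to fit in RAM.
    """

    def __init__(self, cluster_files, batch_size, scale=1.):
        self.cluster_files = list(cluster_files)
        self.batch_size = batch_size
        self.index = []                                   # (cluster, start, end) per batch
        for c, path in enumerate(self.cluster_files):
            size = len(np.load(path, mmap_mode='r'))      # memory-maps the file, no full read
            n = int(size * scale)                         # keep a fraction of each cluster
            for start in range(0, n, batch_size):
                self.index.append((c, start, min(start + batch_size, n)))
        self.dataset_size = sum(end - start for _, start, end in self.index)
        self._cache = (None, None)                        # (cluster id, loaded array)

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        c, start, end = self.index[i]
        if self._cache[0] != c:                           # load a new cluster only when needed
            self._cache = (c, np.load(self.cluster_files[c]))
        x = self._cache[1][start:end]
        return x, x                                       # autoencoder: target = input


# Usage with two small synthetic clusters of 10 and 7 images
tmp = tempfile.mkdtemp()
files = []
for k, n in enumerate([10, 7]):
    path = os.path.join(tmp, f'cluster-{k:03d}.npy')
    np.save(path, np.random.rand(n, 4, 4, 3).astype('float32'))
    files.append(path)

gen = ClusterSequence(files, batch_size=4)
x, y = gen[0]
print(len(gen), gen.dataset_size, x.shape)   # → 5 17 (4, 4, 4, 3)
```

Cutting batches inside each cluster keeps the loading logic simple, at the cost of a few smaller batches at the end of each cluster. Sorting accesses by cluster also avoids reloading the same file over and over.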

## Step 4 - Build model
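Note: The decoder starts from a 6x5x512 tensor, which matches the geometry of the last convolutional output of the encoder: each of the 5 convolutions with `strides=2` halves the image size, so 192x160 becomes 6x5. This is how the decoder is adapted to the encoder.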
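The shapes can be checked by hand. With `padding="same"`, a convolution with `strides=2` outputs `ceil(size/2)`, and each `UpSampling2D` doubles height and width. This small sketch follows the sizes through the encoder and back through the decoder:

```python
import math

def conv_out(size, strides=2):
    # With padding="same", a strided convolution outputs ceil(size / strides)
    return math.ceil(size / strides)

lx, ly = 192, 160
shape = (lx, ly)
for filters in (32, 64, 128, 256, 512):
    shape = tuple(conv_out(s) for s in shape)
    print(f'Conv2D({filters:3d}, strides=2) -> {shape + (filters,)}')

flat = shape[0] * shape[1] * 512
print('Flatten ->', flat)                 # 6*5*512 = 15360

# The decoder goes the other way: 5 UpSampling2D layers, each doubling the size
up = shape
for _ in range(5):
    up = (up[0] * 2, up[1] * 2)
print('Decoder output ->', up + (3,))     # back to (192, 160, 3)
```

Since 192 and 160 are both multiples of 2^5 = 32, the decoder recovers exactly the input size.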

#### Encoder

In [None]:
inputs    = keras.Input(shape=(lx, ly, 3))
# BatchNormalization uses its default axis=-1, i.e. the channel axis of our channels_last images
x         = layers.Conv2D(32,  4, strides=2, padding="same", activation="relu")(inputs)
x         = layers.BatchNormalization()(x)

x         = layers.Conv2D(64,  4, strides=2, padding="same", activation="relu")(x)
x         = layers.BatchNormalization()(x)

x         = layers.Conv2D(128, 4, strides=2, padding="same", activation="relu")(x)
x         = layers.BatchNormalization()(x)

x         = layers.Conv2D(256, 4, strides=2, padding="same", activation="relu")(x)
x         = layers.BatchNormalization()(x)

x         = layers.Conv2D(512, 4, strides=2, padding="same", activation="relu")(x)
x         = layers.BatchNormalization()(x)

x         = layers.Flatten()(x)

z_mean    = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z         = SamplingLayer()([z_mean, z_log_var])

encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
encoder.compile()
# encoder.summary()
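`SamplingLayer` comes from `modules.layers` and is not shown here. Assuming it implements the usual reparameterization trick, as in François Chollet's Keras example, the sampling step can be sketched with NumPy. The function `sample_z` is hypothetical.

```python
import numpy as np

def sample_z(z_mean, z_log_var, rng):
    """Reparameterization trick: z = mean + sigma * epsilon, epsilon ~ N(0, I).

    sigma = exp(0.5 * log_var): the encoder predicts log-variances,
    which can take any real value.
    """
    epsilon = rng.standard_normal(size=z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * epsilon

rng = np.random.default_rng(42)
z_mean = np.full((10000, 300), 2.0)
z_log_var = np.full((10000, 300), np.log(0.25))   # variance 0.25 -> sigma 0.5
z = sample_z(z_mean, z_log_var, rng)
print(z.mean(), z.std())                          # close to 2.0 and 0.5
```

Drawing the randomness in `epsilon` keeps `z` differentiable with respect to `z_mean` and `z_log_var`, which is what allows the encoder to be trained by backpropagation.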

#### Decoder

In [None]:
inputs  = keras.Input(shape=(latent_dim,))

x       = layers.Dense(512*6*5)(inputs)
x       = layers.Reshape((6,5,512))(x)

# BatchNormalization uses its default axis=-1, i.e. the channel axis (channels_last)
x       = layers.UpSampling2D()(x)
x       = layers.Conv2D(512,  kernel_size=3, strides=1, padding='same', activation='relu')(x)
x       = layers.BatchNormalization()(x)

x       = layers.UpSampling2D()(x)
x       = layers.Conv2D(256,  kernel_size=3, strides=1, padding='same', activation='relu')(x)
x       = layers.BatchNormalization()(x)

x       = layers.UpSampling2D()(x)
x       = layers.Conv2D(128,  kernel_size=3, strides=1, padding='same', activation='relu')(x)
x       = layers.BatchNormalization()(x)

x       = layers.UpSampling2D()(x)
x       = layers.Conv2D(64,   kernel_size=3, strides=1, padding='same', activation='relu')(x)
x       = layers.BatchNormalization()(x)

x       = layers.UpSampling2D()(x)
outputs = layers.Conv2D(3,    kernel_size=3, strides=1, padding='same', activation='sigmoid')(x)

decoder = keras.Model(inputs, outputs, name="decoder")
decoder.compile()
# decoder.summary()
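To get an idea of the model size, this sketch counts the weights of the Conv2D and Dense layers above with the usual formulas (weights + biases). BatchNormalization parameters are ignored, so `encoder.summary()` and `decoder.summary()` will show slightly larger totals.

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out              # weight matrix + biases

def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out      # k x k kernels + biases

latent_dim = 300
flat = 6 * 5 * 512

# Encoder: 5 Conv2D (kernel 4), then the z_mean and z_log_var Dense layers
enc_convs = [(3, 32), (32, 64), (64, 128), (128, 256), (256, 512)]
encoder_weights = (sum(conv_params(4, a, b) for a, b in enc_convs)
                   + 2 * dense_params(flat, latent_dim))

# Decoder: one Dense layer, then 5 Conv2D (kernel 3)
dec_convs = [(512, 512), (512, 256), (256, 128), (128, 64), (64, 3)]
decoder_weights = (dense_params(latent_dim, flat)
                   + sum(conv_params(3, a, b) for a, b in dec_convs))

print(encoder_weights, decoder_weights)   # → 12004408 8533635
```

The Dense layers on both sides of the 15360-value bottleneck hold most of the encoder weights and more than half of the decoder weights.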

#### VAE
Our VAE combines the encoder and the decoder. As defined in Step 2, its loss function is the weighted sum of two values, with `loss_weights=[k1,k2]` :  
`total_loss = k1*r_loss + k2*kl_loss`

If `k2 = 0`, the loss function includes only `r_loss` (reconstruction error).  
If `k1 = 0`, the loss function includes only `kl_loss` (dispersion of the latent space).  
In practice, a value of \[.6,.4\] gives good results here.


In [None]:
vae = VAE(encoder, decoder, loss_weights)

vae.compile(optimizer=keras.optimizers.Adam())

## Step 5 - Train
With `scale=1`, 10 epochs take about 20' on a V100 (IDRIS)  
...and on a basic CPU, it may take more than 40 hours !
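A rough way to anticipate the duration of the two configurations of Step 2. This sketch loosely assumes that training time grows linearly with the number of epochs and with `scale`; startup, callbacks and I/O add overheads that it ignores. The function `estimated_minutes` is hypothetical.

```python
def estimated_minutes(epochs, scale, minutes_per_epoch_full=20 / 10):
    """Rough V100 training time, assuming linear growth with epochs and scale.

    minutes_per_epoch_full comes from the figure quoted above:
    about 20 minutes for 10 epochs with scale=1.
    """
    return epochs * scale * minutes_per_epoch_full

print(estimated_minutes(15, 1.))     # full dataset, 15 epochs -> 30.0
print(estimated_minutes(5, 0.01))    # test settings -> about 0.1, plus fixed overheads
```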

### 5.1 - Callbacks

In [None]:
x_draw,_   = data_gen[0]
data_gen.rewind()

callback_images      = ImagesCallback(x=x_draw, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)
callback_bestmodel   = BestModelCallback( run_dir + '/models/best_model.h5' )
callback_tensorboard = TensorBoard(log_dir=run_dir + '/logs', histogram_freq=1)

# callback_tensorboard is defined above : add it to this list to follow training with TensorBoard
callbacks_list = [callback_images, callback_bestmodel]
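`BestModelCallback` comes from `modules.callbacks` and is not shown here. Here is a hypothetical illustration of the idea behind such a callback: save the model only when the monitored loss reaches a new minimum. In Keras, this logic would live in the `on_epoch_end(self, epoch, logs)` method of a `keras.callbacks.Callback` subclass, calling `self.model.save(...)`; the built-in `keras.callbacks.ModelCheckpoint(filepath, monitor='loss', save_best_only=True)` does something similar. The loss values below are made up.

```python
class BestModelSketch:
    """Hypothetical illustration of a 'best model' callback (not BestModelCallback).

    save_fn stands for something like `lambda: model.save(filename)`, passed in
    so that the sketch runs without TensorFlow.
    """

    def __init__(self, save_fn, monitor='loss'):
        self.save_fn = save_fn
        self.monitor = monitor
        self.best = float('inf')
        self.saved_epochs = []

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is not None and current < self.best:   # new minimum: save
            self.best = current
            self.save_fn()
            self.saved_epochs.append(epoch)

# Usage with made-up loss values
cb = BestModelSketch(save_fn=lambda: None)
for epoch, loss in enumerate([3.0, 2.5, 2.7, 2.1, 2.2]):
    cb.on_epoch_end(epoch, {'loss': loss})
print(cb.saved_epochs, cb.best)   # → [0, 1, 3] 2.1
```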

### 5.2 - Train it

In [None]:
chrono = fidle.Chrono()
chrono.start()

history = vae.fit(data_gen, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)

chrono.show()

## Step 6 - Training review
### 6.1 - History

In [None]:
fidle.scrawler.history(history,  plot={"Loss":['loss','r_loss', 'kl_loss']}, save_as='01-history')
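The `history` returned by `fit` exposes its per-epoch metrics as a dict of lists in `history.history`, here with the keys `loss`, `r_loss` and `kl_loss` plotted above. This sketch finds the epoch with the lowest loss; the values are hypothetical, chosen only so that `loss ≈ .6*r_loss + .4*kl_loss`. In the notebook, `history_dict` would be replaced by `history.history`.

```python
import numpy as np

# Hypothetical values, for illustration only (not results from this notebook)
history_dict = {
    'loss':    [0.90, 0.72, 0.65, 0.66, 0.61],
    'r_loss':  [1.10, 0.85, 0.76, 0.77, 0.71],
    'kl_loss': [0.60, 0.52, 0.49, 0.49, 0.46],
}

best_epoch = int(np.argmin(history_dict['loss']))      # 0-based index
print(f"Lowest loss at epoch {best_epoch + 1}: {history_dict['loss'][best_epoch]}")
# → Lowest loss at epoch 5: 0.61
```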

### 6.2 - Reconstruction during training

In [None]:
images_z, images_r = callback_images.get_images( range(0,epochs,2) )

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')

fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_draw[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)


### 6.3 - Generation (latent -> decoder) during training

In [None]:
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')
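Generating images from the latent space means drawing latent vectors from the prior N(0, I) and decoding them. The internals of `ImagesCallback` are not shown here; this sketch assumes that `from_random=True` does something along these lines. `decode` stands for something like `decoder.predict`, and `fake_decode` is a hypothetical stand-in so that the sketch runs without TensorFlow.

```python
import numpy as np

def generate_from_prior(decode, n_images, latent_dim, rng):
    """Draw latent vectors from the prior N(0, I) and decode them."""
    z = rng.standard_normal(size=(n_images, latent_dim))
    return decode(z)

# Hypothetical stand-in decoder: maps each z to a uniform 192x160x3 image in (0, 1)
def fake_decode(z):
    value = 1 / (1 + np.exp(-z.mean(axis=1)))            # sigmoid, like the real output layer
    return np.broadcast_to(value[:, None, None, None], (len(z), 192, 160, 3))

images = generate_from_prior(fake_decode, n_images=5, latent_dim=300,
                             rng=np.random.default_rng(0))
print(images.shape)   # → (5, 192, 160, 3)
```

The KL term of the loss is what makes this work: it pushes the encoded distribution toward N(0, I), so random vectors drawn from that prior land in regions the decoder has learned.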

In [None]:
fidle.end()

---
<img width="80px" src="../fidle/img/00-Fidle-logo-01.svg"></img>