<img width="800px" src="../fidle/img/header.svg"></img>

# <!-- TITLE --> [LVAE1] - First VAE, using Lightning API (MNIST dataset)
<!-- DESC --> Construction and training of a VAE with a small latent space, using the PyTorch Lightning API

<!-- AUTHOR : Achille Mbogol Touye (EFIlIA-MIAI/SIMaP) -->

## Objectives :
 - Understanding and implementing a **variational autoencoder** neural network (VAE)
 - Understanding the **Lightning API**, using two custom layers

Because the computational needs are significant, it is preferable to start with a very simple dataset such as MNIST.  
...and to use MNIST at a small scale if you don't have a GPU ;-)

## What we're going to do :

 - Define a VAE model
 - Build the model
 - Train it
 - Have a look at the training process


## Step 1 - Init python stuff

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import lightning.pytorch as pl

from modules.datagen import MNIST
from torch.utils.data import TensorDataset, DataLoader
from modules.progressbar import CustomTrainProgressBar
from modules.callbacks import ImagesCallback, BestModelCallback
from modules.layers import SamplingLayer, VariationalLossLayer
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('LVAE1')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Step 2 - Parameters
`scale` : With scale=1, training takes about 1 min 30 s on a V100 GPU ...and more than 20 min on a CPU!\
`latent_dim` : 2 dimensions is small, but useful for plotting!\
`fit_verbosity`: Verbosity of the training progress bar: 0=silent, 1=progress bar, 2=one line 

`loss_weights` : Our **loss function** is the weighted sum of two losses:
 - `r_loss`, which measures the reconstruction error. 
 - `kl_loss`, the Kullback-Leibler term, which measures the dispersion of the latent distribution. 

The weights are defined by: `loss_weights=[k1,k2]` where : `vae_loss = k1*r_loss + k2*kl_loss` 
In practice, a value of \[1,.001\] gives good results here.
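
As an illustration of this weighted sum, here is a minimal sketch. The actual `VariationalLossLayer` is defined in `modules/layers` and may differ; the function name `vae_loss_sketch`, the use of binary cross-entropy for `r_loss`, and the mean reduction are all assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def vae_loss_sketch(x, x_hat, z_mean, z_logvar, loss_weights=(1, 0.001)):
    """Hypothetical weighted VAE loss: k1*r_loss + k2*kl_loss."""
    k1, k2 = loss_weights
    # Reconstruction term (assumption: pixel-wise binary cross-entropy)
    r_loss = F.binary_cross_entropy(x_hat, x, reduction="mean")
    # KL divergence between N(z_mean, exp(z_logvar)) and N(0, 1), averaged
    kl_loss = -0.5 * torch.mean(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return r_loss, kl_loss, k1 * r_loss + k2 * kl_loss

# Random stand-ins for a batch of 4 images and their latent parameters
x        = torch.rand(4, 1, 28, 28)
x_hat    = torch.rand(4, 1, 28, 28)
z_mean   = torch.zeros(4, 2)
z_logvar = torch.zeros(4, 2)

r_loss, kl_loss, loss = vae_loss_sketch(x, x_hat, z_mean, z_logvar)
```

With `z_mean = 0` and `z_logvar = 0`, the encoded distribution is already a standard normal, so the KL term vanishes and only the reconstruction term contributes.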


In [None]:
latent_dim = 2
loss_weights = [1,.001]

scale = 0.2
seed = 123

batch_size = 64
epochs = 10
fit_verbosity = 1

Override parameters (batch mode) - Just forget this cell

In [None]:
fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')

## Step 3 - Prepare data
`MNIST.get_data()` returns: `x_train,y_train, x_test,y_test`, \
but we only need `x_train` for our training.

In [None]:
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )

fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')

### 3.1 - Use a DataLoader to train the model
The `Dataset` retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in minibatches and reshuffle the data at every epoch to reduce overfitting. `DataLoader` is an iterable that abstracts this complexity behind a simple API. Note that the loader below is created with `shuffle=False`, so the batches are served in the same order at every epoch.
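
To make the minibatch behaviour concrete, here is a small sketch on random tensors. The names `x_demo`, `y_demo` and `demo_loader` are hypothetical, and the `(N, 1, 28, 28)` shape is an assumption about what `MNIST.get_data()` returns.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Random stand-ins for 10 MNIST-like images and their labels
x_demo = torch.rand(10, 1, 28, 28)
y_demo = torch.randint(0, 10, (10,))

# shuffle=True reshuffles the samples at the start of every epoch
demo_loader = DataLoader(TensorDataset(x_demo, y_demo), batch_size=4, shuffle=True)

# One pass over the loader is one epoch; the last batch may be smaller
batch_shapes = [tuple(xb.shape) for xb, yb in demo_loader]
print(batch_shapes)
```

With 10 samples and `batch_size=4`, the loader yields two full batches and one final batch of 2, because `drop_last` defaults to `False`.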

In [None]:
train_dataset = TensorDataset(x_data,y_data)

# Training batches
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=2
)


## Step 4 - Build model
In this example, we will use the **PyTorch Lightning API**. 
For this, we will use two custom layers:
 - `SamplingLayer`, which generates a vector z from the parameters z_mean and z_logvar - See : [SamplingLayer.py](./modules/layers/SamplingLayer.py)
 - `VariationalLossLayer`, which allows us to calculate the loss function, loss - See : [VariationalLossLayer.py](./modules/layers/VariationalLossLayer.py)
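
The sampling step is usually implemented with the reparameterization trick. The sketch below shows one way to do it, as an assumption about what such a layer does rather than the actual content of `SamplingLayer.py`; the class name `SamplingSketch` is hypothetical.

```python
import torch
import torch.nn as nn

class SamplingSketch(nn.Module):
    """Hypothetical sampling layer: z = z_mean + sigma * epsilon."""
    def forward(self, inputs):
        z_mean, z_logvar = inputs
        # epsilon ~ N(0, 1), same shape as z_mean
        epsilon = torch.randn_like(z_mean)
        # sigma = exp(0.5 * log(sigma^2))
        return z_mean + torch.exp(0.5 * z_logvar) * epsilon

z_mean   = torch.tensor([[0.5, -1.0]])
z_logvar = torch.full((1, 2), -100.0)   # near-zero variance
z = SamplingSketch()([z_mean, z_logvar])
```

Because the randomness is isolated in `epsilon`, gradients can flow through `z_mean` and `z_logvar`. With a very negative `z_logvar`, as here, the sample collapses onto `z_mean`.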

#### Encoder

In [None]:
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.Convblock=nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Flatten(),

            nn.Linear(64*7*7, 16),
            nn.BatchNorm1d(16),
            nn.LeakyReLU(0.2),
        )

        self.z_mean   = nn.Linear(16, latent_dim)
        self.z_logvar = nn.Linear(16, latent_dim)

    def forward(self, x):
        x        = self.Convblock(x)
        z_mean   = self.z_mean(x)
        z_logvar = self.z_logvar(x)
        z        = SamplingLayer()([z_mean, z_logvar])

        return z_mean, z_logvar, z
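
The `64*7*7` input size of the first `nn.Linear` follows from the convolution settings. As a quick check, the sketch below applies the standard `Conv2d` output-size formula (dilation 1), assuming 28x28 input images; the helper name `conv2d_out` is hypothetical.

```python
def conv2d_out(size, kernel_size, stride, padding):
    """Spatial output size of a Conv2d layer (dilation = 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# (kernel_size, stride, padding) of the four encoder convolutions
encoder_convs = [(3, 1, 1), (3, 2, 1), (3, 2, 1), (3, 1, 1)]

size = 28                     # assumption: 28x28 MNIST images
encoder_sizes = [size]
for k, s, p in encoder_convs:
    size = conv2d_out(size, k, s, p)
    encoder_sizes.append(size)

flat_features = 64 * size * size   # 64 channels after the last convolution
print(encoder_sizes, flat_features)
```

The two stride-2 convolutions halve the resolution twice (28 to 14 to 7), which gives 64 feature maps of 7x7 before `nn.Flatten()`.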

#### Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.linear=nn.Sequential(
            nn.Linear(latent_dim, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),

            nn.Linear(16, 64*7*7),
            nn.BatchNorm1d(64*7*7),
            nn.ReLU()
        )

        self.Deconvblock=nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        x     = self.linear(z)
        x     = x.reshape(-1,64,7,7)
        x_hat = self.Deconvblock(x)
        return x_hat
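
The decoder mirrors the encoder and has to bring the 7x7 feature maps back to 28x28. The sketch below checks this with the standard `ConvTranspose2d` output-size formula (dilation 1, no output padding); the helper name `conv_transpose2d_out` is hypothetical.

```python
def conv_transpose2d_out(size, kernel_size, stride, padding, output_padding=0):
    """Spatial output size of a ConvTranspose2d layer (dilation = 1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

# (kernel_size, stride, padding) of the four decoder transposed convolutions
decoder_deconvs = [(3, 1, 1), (4, 2, 1), (4, 2, 1), (3, 1, 1)]

size = 7                      # spatial size after x.reshape(-1, 64, 7, 7)
decoder_sizes = [size]
for k, s, p in decoder_deconvs:
    size = conv_transpose2d_out(size, k, s, p)
    decoder_sizes.append(size)

print(decoder_sizes)
```

The two stride-2 layers use `kernel_size=4` so that each one exactly doubles the resolution (7 to 14 to 28).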

#### VAE

We will calculate the loss with a specific layer: `VariationalLossLayer` - See : [VariationalLossLayer.py](./modules/layers/VariationalLossLayer.py)

In [None]:
class LitVAE(pl.LightningModule):

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    # forward pass
    def forward(self, x):
        z_mean, z_logvar, z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, _ = batch
        z_mean, z_logvar, z = self.encoder(x)
        x_hat = self.decoder(z)

        r_loss,kl_loss,loss = VariationalLossLayer(loss_weights=loss_weights)([x, z_mean,z_logvar,x_hat])

        metrics = {"r_loss"   : r_loss,
                   "kl_loss"  : kl_loss,
                   "vae_loss" : loss
                  }

        # log metrics, aggregated over each epoch (on_step=False, on_epoch=True)
        self.log_dict(metrics,
                      on_step  = False,
                      on_epoch = True,
                      prog_bar = True,
                      logger   = True
                     )

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


In [None]:
# print model
vae=LitVAE(Encoder(latent_dim=latent_dim),Decoder(latent_dim=latent_dim))
print(vae)

## Step 5 - Train
### 5.1 - Using two nice custom callbacks :-)
Two custom callbacks are used:
 - `ImagesCallback` : saves images during training - See [ImagesCallback.py](./modules/callbacks/ImagesCallback.py)
 - `BestModelCallback` : saves the best model - See [BestModelCallback.py](./modules/callbacks/BestModelCallback.py)

In [None]:
# save best model
save_dir = "./run/models/"
BestModelCallback = BestModelCallback(dirpath= save_dir) 
CallbackImages = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)
logger= TensorBoardLogger(save_dir='VAE1_logs',name="VAE_logs") # loggers data

### 5.2 - Let's train !
With `scale=1`, training takes about 1 min 15 s on a GPU (V100 at IDRIS) ...or 20 min on a CPU 

In [None]:
chrono=fidle.Chrono()
chrono.start()

# train model
trainer = pl.Trainer(accelerator='auto',
                     max_epochs=epochs,
                     logger=logger,
                     num_sanity_val_steps=0,
                     callbacks=[CustomTrainProgressBar(), BestModelCallback, CallbackImages]
                    )

trainer.fit(model=vae, train_dataloaders=train_loader)



chrono.show()

## Step 6 - Training review
### 6.1 - History

In [None]:
# launch Tensorboard 
%reload_ext tensorboard
%tensorboard --logdir=./VAE1_logs/VAE_logs/ --bind_all

### 6.2 - Reconstruction during training
At the end of each epoch, our callback saved some reconstructed images. 
Where : 
Original image -> encoder -> z -> decoder -> Reconstructed image

In [None]:
images_z, images_r = CallbackImages.get_images( range(0,epochs,2) )

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)

fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-reconstruct')

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)


### 6.3 - Generation (latent -> decoder)

In [None]:
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-generated')

## Annex - Model save and reload 

In [None]:
#---- Load the model from a checkpoint
loaded_model = LitVAE.load_from_checkpoint(BestModelCallback.best_model_path,
                                           encoder=Encoder(latent_dim=latent_dim),
                                           decoder=Decoder(latent_dim=latent_dim))
# put model in evaluation mode
loaded_model.eval()

# ---- Retrieve a layer decoder
decoder=loaded_model.decoder

# example of z
z = torch.Tensor([[-1,.1]]).to(loaded_model.device)   # same device as the loaded model
img = decoder(z)

fidle.scrawler.images(img.cpu().detach(), x_size=2,y_size=2, save_as='04-example')

In [None]:
fidle.end()

---
<img width="80px" src="../fidle/img/logo-paysage.svg"></img>