<img width="800px" src="../fidle/img/header.svg"></img>

# <!-- TITLE --> [LVAE3] - Analysis of the VAE's latent space of MNIST dataset
<!-- DESC --> Visualization and analysis of the VAE's latent space of the MNIST dataset, using PyTorch Lightning
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->

## Objectives :
 - First data generation from **latent space** 
 - Understanding of underlying principles
 - Model management

Here, we no longer consume data, we generate it! ;-)

## What we're going to do :

 - Load a saved model
 - Reconstruct some images
 - Latent space visualization
 - Matrix of generated images


## Step 1 - Init python stuff

### 1.1 - Init python

In [None]:
import os
import sys
import torch
import pandas as pd
import numpy  as np
import torch.nn as nn

from modules.callbacks import ImagesCallback, BestModelCallback
from modules.datagen   import MNIST
from modules.models    import Encoder, Decoder, VAE 


import scipy.stats
import matplotlib
import matplotlib.pyplot as plt
from barviz import Simplex
from barviz import Collection


import fidle

# ---- Fidle environment : run identifier, run directory and datasets directory
run_id, run_dir, datasets_dir = fidle.init('LVAE3')

# ---- Compute device : GPU when CUDA is available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 1.2 - Parameters

In [None]:
# Notebook parameters (may be overridden in batch mode by the next cell).
# NOTE(review): despite its name, models_dir holds the path of a checkpoint
# file; it is passed as-is to VAE.load_from_checkpoint further down.
models_dir = "./run/models_dir/best-model-epoch=4-loss=0.00.ckpt"
seed = 123    # forwarded to MNIST.get_data(seed=...)
scale = 1     # forwarded to MNIST.get_data(scale=...)

Override parameters (batch mode) - Just forget this cell

In [None]:
# Batch mode only : let Fidle override the three parameters named below.
# NOTE(review): the override source (env vars, config file, ...) is not
# visible from this notebook - check the fidle documentation.
fidle.override('scale', 'seed', 'models_dir')

## Step 2 - Get data

In [None]:
# Keep the first two returned items (images and labels); the last two are
# discarded. train_prop=1 is passed as in the original call - presumably it
# puts the whole dataset in the "train" part; verify against MNIST.get_data.
x_data, y_data, _, _ = MNIST.get_data(train_prop=1, scale=scale, seed=seed)

## Step 3 - Reload best model

In [None]:
# ---- Load the model from a checkpoint
#
# latent_dim must match the latent dimension the checkpoint was trained with,
# since the encoder/decoder are rebuilt here before the weights are restored.
latent_dim = 6

# map_location remaps the stored tensors onto the device selected at init,
# so a checkpoint saved on GPU can still be opened on a CPU-only machine.
vae = VAE.load_from_checkpoint(
    models_dir,
    map_location=device,
    encoder=Encoder(latent_dim=latent_dim),
    decoder=Decoder(latent_dim=latent_dim),
)

# The following cells send their inputs to `device`; make sure the model
# weights live on that same device to avoid a device mismatch at inference.
vae.to(device)

# put model in evaluation mode
vae.eval()

## Step 4 - Image reconstruction

In [None]:
# ---- Select few images

n_images = 10
x_show   = fidle.utils.pick_dataset(x_data, n=n_images)

# ---- Get latent points and reconstructed images
#      Pure inference : no_grad avoids building a useless autograd graph.

with torch.no_grad():
    z_mean, z_var, z = vae.encoder(x_show.to(device))
    x_reconst        = vae.decoder(z)

# Latent dimension as actually produced by the encoder (second axis of z)
latent_dim = z.shape[1]

# ---- Show it
#      Bring results back to CPU before plotting

z         = z.detach().cpu()
x_reconst = x_reconst.detach().cpu()

fidle.utils.subtitle('Originals :')
fidle.scrawler.images(x_show,    None, indices='all', columns=n_images, x_size=2,y_size=2, save_as='01-original')
fidle.utils.subtitle('Reconstructed :')
fidle.scrawler.images(x_reconst, None, indices='all', columns=n_images, x_size=2,y_size=2, save_as='02-reconstruct')



## Step 5 - Visualizing the latent space

In [None]:
n_show = 5000

# ---- Select images

x_show, y_show   = fidle.utils.pick_dataset(x_data,y_data, n=n_show)

# ---- Get latent points
#      eval() does not disable autograd : wrap the forward pass in no_grad so
#      that no computation graph is kept in memory for these n_show images.

with torch.no_grad():
    z_mean, z_var, z = vae.encoder(x_show.to(device))


### 5.1 - Classic 2d visualisation

In [None]:
# Plot the first two latent coordinates of each point, coloured by y_show.
z = z.detach().cpu()      # plain CPU tensor, as needed for plotting

fig = plt.figure(figsize=(14, 10))
latent_x, latent_y = z[:, 0], z[:, 1]
plt.scatter(latent_x, latent_y, c=y_show, cmap='tab10', alpha=0.5, s=30)
plt.colorbar()

# Save the figure through Fidle before displaying it
fidle.scrawler.save_fig('03-Latent-space')
plt.show()

### 5.2 - Simplex visualisation

In [None]:
if latent_dim<4:

    print('Sorry, This part can only work if the latent space is greater than 3')

else:

    # ---- Softmax rescale
    #      Each latent point becomes a set of positive weights summing to 1,
    #      which is what the simplex plot is fed with. torch.softmax is the
    #      numerically stable form of exp(z)/sum(exp(z)).
    #
    zs = torch.softmax(z.detach().cpu(), dim=1)

    # ---- Create collection
    #      Marker colours are the labels in y_show, so the colour scale must
    #      span the label range (not the latent dimension), otherwise the
    #      highest classes all get the same colour.
    #
    c = Collection(zs, colors=y_show, labels=y_show)
    c.attrs.markers_colormap     = {'colorscale':'Rainbow','cmin':0,'cmax':int(y_show.max())}
    c.attrs.markers_size         = 5
    c.attrs.markers_border_width = 0
    c.attrs.markers_opacity      = 0.8

    # ---- Build a simplex from latent_dim and plot the collection on it
    #
    s = Simplex.build(latent_dim)
    s.attrs.width  = 1000
    s.attrs.height = 1000
    s.plot(c)

## Step 6 - Generate from latent space (latent_dim==2)

In [None]:
# This cell indexes exactly two latent coordinates and decodes 2d points,
# so it is only meaningful when the latent space has dimension 2.
if latent_dim != 2:

    print('Sorry, This part can only work if the latent space is of dimension 2')

else:

    grid_size   = 14
    grid_scale  = 1.

    # ---- Draw a ppf grid
    #      Coordinates are normal quantiles (percent point function), so the
    #      grid is denser near the origin. Rows go from top (high y) to bottom,
    #      columns from left (low x) to right.

    xs   = scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size), scale=grid_scale)
    ys   = scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size), scale=grid_scale)
    grid = np.array([ (x, y) for y in ys for x in xs ])

    # ---- Draw latent points and grid

    fig = plt.figure(figsize=(12, 10))
    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)
    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)
    fidle.scrawler.save_fig('04-Latent-grid')
    plt.show()

    # ---- Plot grid corresponding images
    #      The numpy grid is float64 : cast it to float32 (default dtype of
    #      torch module weights) and feed the decoder with a tensor, exactly
    #      as in the reconstruction step (not a list wrapping the tensor).
    grid_points = torch.from_numpy(grid).float().to(device)
    with torch.no_grad():
        x_reconst = vae.decoder(grid_points)
    x_reconst = x_reconst.detach().cpu()
    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='05-Latent-morphing')



In [None]:
# Last Fidle call of the notebook.
# NOTE(review): presumably closes the session opened by fidle.init() and
# reports run information - confirm against the fidle documentation.
fidle.end()

---
<img width="80px" src="../fidle/img/logo-paysage.svg"></img>