<img width="800px" src="../fidle/img/header.svg"></img>

# <!-- TITLE --> [PLSHEEP3] - A DCGAN to Draw a Sheep, using PyTorch Lightning
<!-- DESC --> "Draw me a sheep", revisited with a DCGAN, using PyTorch Lightning
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->

## Objectives :
 - Build and train a DCGAN model with the Quick Draw dataset
 - Understand how a DCGAN works

The [Quick Draw dataset](https://quickdraw.withgoogle.com/data) contains about 50,000,000 drawings made by real people.  
We use a subset of 117,555 sheep drawings.  
To get the dataset : [https://github.com/googlecreativelab/quickdraw-dataset](https://github.com/googlecreativelab/quickdraw-dataset)  
Datasets as NumPy bitmap files : [https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap](https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap)   
Sheep dataset : [https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy](https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy) (94.3 MB)
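According to the quickdraw-dataset repository, the numpy bitmap files store every drawing as a 28x28 grayscale image, flattened into one row of 784 `uint8` values. Below is a minimal loading sketch under that assumption. It further assumes that `scale` (used later in this notebook) is the fraction of drawings kept; that is a guess about `QuickDrawDataModule`, not its actual code. `load_quickdraw` is a hypothetical helper, and a small random array replaces `sheep.npy` so the example runs without the download.

```python
import os
import tempfile

import numpy as np


def load_quickdraw(path, scale=1.0, seed=0):
    """Hypothetical helper: load a Quick Draw .npy bitmap file.

    Returns a float32 array of shape (n, 28, 28, 1) with values in [0, 1].
    """
    raw = np.load(path)                                   # shape (N, 784), dtype uint8
    n = max(1, int(len(raw) * scale))                     # assumed meaning of `scale`
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(raw), size=n, replace=False)     # random subset, no duplicates
    return raw[idx].reshape(-1, 28, 28, 1).astype(np.float32) / 255.0


# Stand-in for sheep.npy: 1000 random "drawings" with the same layout
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sheep.npy")
    fake = np.random.default_rng(1).integers(0, 256, size=(1000, 784), dtype=np.uint8)
    np.save(path, fake)
    x = load_quickdraw(path, scale=0.1)

print(x.shape, x.dtype, x.min(), x.max())
```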


## What we're going to do :

 - Have a look at the dataset
 - Define a GAN model
 - Build the model
 - Train it
 - Have a look at the results

## Step 1 - Init and parameters
#### Python init

In [None]:
import os
import sys
import shutil

import numpy as np
import torch
from lightning import Trainer
from lightning.pytorch.callbacks                        import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard              import TensorBoardLogger

import fidle

from modules.QuickDrawDataModule import QuickDrawDataModule

from modules.GAN                 import GAN
from modules.WGANGP              import WGANGP
from modules.Generators          import *
from modules.Discriminators      import *

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('PLSHEEP3')

#### A few parameters
With `scale=1` and `epochs=20`, training takes about 22 minutes on a V100.

In [None]:
latent_dim          = 128

gan_name            = 'WGANGP'
generator_name      = 'Generator_2'
discriminator_name  = 'Discriminator_3'
    
scale               = 0.001
epochs              = 4
num_workers         = 2
lr                  = 0.0001
b1                  = 0.5
b2                  = 0.999
lambda_gp           = 10
batch_size          = 64
num_img             = 48
fit_verbosity       = 2
    
dataset_file        = datasets_dir+'/QuickDraw/origine/sheep.npy' 
data_shape          = (28,28,1)
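Together, `scale` and `batch_size` determine the number of optimisation steps in one epoch. The arithmetic below is a sketch that assumes `scale` is the fraction of the 117,555 drawings kept by the DataModule and that the last, partial batch is not dropped. `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(n_total, scale, batch_size):
    """Number of batches per epoch, assuming the last partial batch is kept."""
    n_used = int(n_total * scale)
    return math.ceil(n_used / batch_size)


n_sheep = 117_555
print(steps_per_epoch(n_sheep, 0.001, 64))   # default settings of this notebook
print(steps_per_epoch(n_sheep, 1.0, 64))     # full dataset
```

Under these assumptions, the default `scale = 0.001` yields only about 2 batches per epoch. That is far fewer than the `log_every_n_steps = batch_size` interval given to the `Trainer` further down, so step-level metrics may be logged seldom or not at all during such a short run. Lightning typically emits a warning in this case.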

Override parameters (batch mode). You can safely ignore this cell.

In [None]:
fidle.override('latent_dim', 'gan_name', 'generator_name', 'discriminator_name')  
fidle.override('epochs', 'lr', 'b1', 'b2', 'batch_size', 'num_img', 'fit_verbosity')
fidle.override('dataset_file', 'data_shape', 'scale', 'num_workers' )

#### Cleaning

In [None]:
# Comment out these lines to keep the results of previous runs
shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)
shutil.rmtree(f'{run_dir}/models', ignore_errors=True)
shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)

## Step 2 - Get some nice data

#### Get a Nice DataModule
Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py)   
This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)

In [None]:
dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=num_workers)
dm.setup()
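A LightningDataModule mainly bundles a `Dataset` together with the `DataLoader`s built from it. The sketch below shows, in plain PyTorch, what `train_dataloader()` might return. It is not the code of `QuickDrawDataModule`: the `QuickDrawImages` class is hypothetical, and the channel-first layout `(batch, 1, 28, 28)` is an assumption, chosen because PyTorch convolution layers expect channels first.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class QuickDrawImages(Dataset):
    """Hypothetical Dataset: returns one (1, 28, 28) tensor per drawing."""

    def __init__(self, images):
        # (n, 28, 28, 1) channels-last -> (n, 1, 28, 28) channels-first
        self.x = torch.from_numpy(images).permute(0, 3, 1, 2).contiguous()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i]


images = np.random.rand(200, 28, 28, 1).astype(np.float32)   # stand-in data
loader = DataLoader(QuickDrawImages(images), batch_size=64, shuffle=True, num_workers=0)

batch = next(iter(loader))   # the default collate function stacks the tensors
print(batch.shape, len(loader))
```

The dataset returns bare tensors instead of `(x, y)` pairs, so a batch is a single tensor. This matches how `batch_data` is used directly in the cells below.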

#### Have a look

In [None]:
dl         = dm.train_dataloader()
batch_data = next(iter(dl))

fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')

## Step 3 - Get a nice GAN model

Our Generators are defined in [./modules/Generators.py](./modules/Generators.py)  
Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py)  


Our GANs are defined in :
 - [./modules/GAN.py](./modules/GAN.py)  
 - [./modules/WGANGP.py](./modules/WGANGP.py)  
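The contents of `Generator_2` and `Discriminator_3` are not shown here. To illustrate the DCGAN idea only, here is a minimal sketch using hypothetical classes (`TinyGenerator`, `TinyCritic`), not the ones in `./modules`. The generator projects the latent vector onto a 7x7 feature map, then upsamples it twice with stride-2 transposed convolutions (7 to 14 to 28). The critic mirrors this with stride-2 convolutions (28 to 14 to 7) followed by a linear layer. As recommended for WGAN-GP, the critic has no batch normalization and no final sigmoid. The generator's `Sigmoid` output assumes images scaled to [0, 1].

```python
import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """DCGAN-style generator: latent vector -> (1, 28, 28) image."""

    def __init__(self, latent_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (128, 7, 7)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14 -> 28
            nn.Sigmoid(),
        )

    def forward(self, z):
        return self.net(z)


class TinyCritic(nn.Module):
    """DCGAN-style critic: (1, 28, 28) image -> one unbounded score."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
        )

    def forward(self, x):
        return self.net(x)


z = torch.randn(16, 128)
fake = TinyGenerator(128)(z)
score = TinyCritic()(fake)
print(fake.shape, score.shape)
```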


#### Retrieve class by name
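To stay flexible, we only specify class names as parameters.  
The code below retrieves the classes from their names. This works because the `import *` statements above brought them into the notebook's `__main__` namespace.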

In [None]:
module=sys.modules['__main__']
Generator_     = getattr(module, generator_name)
Discriminator_ = getattr(module, discriminator_name)
GAN_           = getattr(module, gan_name)
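`getattr` on a module object is all this lookup needs. The sketch below reproduces the pattern with a stand-in module and hypothetical classes (`Generator_A`, `Generator_B`), and it adds a clearer error message when a name is misspelled. It also shows `importlib.import_module`, which can fetch a class from a module given by its dotted name without relying on `import *`.

```python
import importlib
import types


def class_from_name(name, module):
    """Return the attribute called `name` from `module`, with a clear error if missing."""
    try:
        return getattr(module, name)
    except AttributeError:
        raise ValueError(f"No class named {name!r} in module {module.__name__!r}") from None


# Stand-in for `modules.Generators` (hypothetical classes)
class Generator_A:
    pass


class Generator_B:
    pass


generators = types.ModuleType("generators")
generators.Generator_A = Generator_A
generators.Generator_B = Generator_B

Generator_ = class_from_name("Generator_B", generators)
print(Generator_.__name__)

# Same mechanism with a module imported from its dotted name (standard library here)
OrderedDict_ = class_from_name("OrderedDict", importlib.import_module("collections"))
print(OrderedDict_.__name__)
```

In this notebook, assuming `modules` is importable from the working directory, `class_from_name(generator_name, importlib.import_module('modules.Generators'))` would be an equivalent that does not need `import *`.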

#### Basic test - Just to be sure it (could) work... ;-)

In [None]:
generator     = Generator_(     latent_dim=latent_dim, data_shape=data_shape )
discriminator = Discriminator_( latent_dim=latent_dim, data_shape=data_shape )

print('\nFew tests :\n')
z = torch.randn(batch_size, latent_dim)
print('z size        : ',z.size())

fake_img = generator.forward(z)
print('fake_img      : ', fake_img.size())

p = discriminator.forward(fake_img)
print('pred fake     : ', p.size())

print('batch_data    : ',batch_data.size())

p = discriminator.forward(batch_data)
print('pred real     : ', p.size())

print('\nShow fake images :')
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='02-Sheeps-fake')

In [None]:
print('Fake images : ', fake_img.size())
print('Real images : ', batch_data.size())

# One random weight per sample, reshaped to (batch_size, 1, 1, 1)
# so that it broadcasts over all the pixels of that sample
e = torch.distributions.uniform.Uniform(0, 1).sample([batch_size,1])
e = e[:, None, None]
i = fake_img * e + (1-e)*batch_data

print('\nInterpolated images :')
nimg = i.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='03-Sheeps-interpolated')
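This interpolation is the first half of the WGAN-GP gradient penalty. Each sample gets its own random weight ε in [0, 1], and that weight must be broadcast over every pixel of the sample. The sketch below uses random tensors in place of the real and fake batches and assumes channel-first `(batch, 1, 28, 28)` images. It shows the weight shape that makes this broadcasting work, then checks that every interpolated pixel lies between its real and fake counterparts.

```python
import torch

batch_size = 8
real = torch.rand(batch_size, 1, 28, 28)   # stand-in for batch_data
fake = torch.rand(batch_size, 1, 28, 28)   # stand-in for the generator output

# One weight per sample, shape (batch, 1, 1, 1), broadcast over C, H and W
eps = torch.rand(batch_size, 1, 1, 1)
mixed = eps * fake + (1 - eps) * real

low = torch.minimum(real, fake)
high = torch.maximum(real, fake)
print(mixed.shape, bool(((mixed >= low - 1e-6) & (mixed <= high + 1e-6)).all()))
```

A weight of shape `(batch_size,)` or `(batch_size, 1)` would instead be aligned with the trailing image dimensions. That either fails to broadcast or silently mixes the wrong values.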


## Step 4 - Build the GAN model
To keep our code simple, the GAN classes are defined separately in the modules [./modules/GAN.py](./modules/GAN.py) and [./modules/WGANGP.py](./modules/WGANGP.py).  
Passing the generator and discriminator class names as parameters keeps the code modular and makes it possible to use the PL checkpoints.

In [None]:
gan = GAN_( data_shape          = data_shape,
            lr                  = lr,
            b1                  = b1,
            b2                  = b2,
            lambda_gp           = lambda_gp,
            batch_size          = batch_size, 
            latent_dim          = latent_dim, 
            generator_name      = generator_name, 
            discriminator_name  = discriminator_name)
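Storing class *names* rather than class objects is what makes checkpoints easy to reload. Lightning's `save_hyperparameters()` records the constructor arguments in the checkpoint under the `hyper_parameters` key, and `load_from_checkpoint` calls the constructor again with them. Presumably, `./modules/GAN.py` relies on this mechanism. The plain-PyTorch sketch below imitates the idea without Lightning: the checkpoint holds plain strings and numbers plus a `state_dict`, and the network is rebuilt from a name registry. `TinyNet`, `REGISTRY`, `save_checkpoint` and `load_checkpoint` are hypothetical names.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 4)

    def forward(self, z):
        return self.fc(z)


REGISTRY = {"TinyNet": TinyNet}   # name -> class, like getattr on a module


def save_checkpoint(model, hparams, f):
    # hparams only holds strings and numbers, never class objects
    torch.save({"hyper_parameters": hparams, "state_dict": model.state_dict()}, f)


def load_checkpoint(f):
    ckpt = torch.load(f)
    hp = ckpt["hyper_parameters"]
    model = REGISTRY[hp["net_name"]](latent_dim=hp["latent_dim"])   # rebuild from the name
    model.load_state_dict(ckpt["state_dict"])
    return model


net = TinyNet(latent_dim=8)
buffer = io.BytesIO()
save_checkpoint(net, {"net_name": "TinyNet", "latent_dim": 8}, buffer)
buffer.seek(0)
restored = load_checkpoint(buffer)

z = torch.randn(3, 8)
print(torch.allclose(net(z), restored(z)))
```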

## Step 5 - Train it !
#### Instantiate Callbacks, Logger & co.
More about :
- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)
- [ModelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)

In [None]:

# ---- for tensorboard logs
#
logger       = TensorBoardLogger(       save_dir       = f'{run_dir}',
                                        name           = 'tb_logs'  )

log_dir = os.path.abspath(f'{run_dir}/tb_logs')
print('To access the logs with tensorboard, use this command line :')
print(f'tensorboard --logdir {log_dir}')

# ---- To save checkpoints
#
callback_checkpoints = ModelCheckpoint( dirpath        = f'{run_dir}/models', 
                                        filename       = 'bestModel', 
                                        save_top_k     = 1, 
                                        save_last      = True,
                                        every_n_epochs = 1, 
                                        monitor        = "g_loss")
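With `save_top_k=1` and `monitor="g_loss"`, `ModelCheckpoint` keeps a single `bestModel.ckpt`, which is replaced whenever the monitored value improves. `mode` defaults to `"min"`, so lower is better. `save_last=True` additionally maintains a `last.ckpt`. The pure-Python sketch below only simulates this top-1 selection rule on made-up loss values. `select_best` is hypothetical, not Lightning code.

```python
def select_best(history, mode="min"):
    """Return (epoch, value) of the checkpoint a top-1 policy would keep."""
    if mode == "min":
        better = lambda new, old: new < old
    else:
        better = lambda new, old: new > old
    best = None
    for epoch, value in enumerate(history):
        if best is None or better(value, best[1]):
            best = (epoch, value)   # this epoch would overwrite bestModel.ckpt
    return best


g_losses = [1.9, 1.2, 1.4, 0.8, 1.1]   # made-up values, for illustration only
print(select_best(g_losses))
print(select_best(g_losses, mode="max"))
```

If `g_loss` is the usual WGAN generator loss (minus the mean critic score on fakes), it is measured against a critic that keeps changing. The lowest value is therefore not guaranteed to give the best-looking sheep.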

#### Train it

In [None]:

trainer = Trainer(
    accelerator        = "auto",
    max_epochs         = epochs,
    callbacks          = [callback_checkpoints],
    log_every_n_steps  = batch_size,
    logger             = logger
)

trainer.fit(gan, dm)
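`trainer.fit()` runs the optimisation loop defined inside the LightningModule. For orientation, here is a minimal plain-PyTorch sketch of one WGAN-GP step, using the notebook's `lr`, `b1`, `b2` and `lambda_gp` values. It follows the usual formulation: the critic loss is the mean score on fakes, minus the mean score on reals, plus λ times the gradient penalty, and the generator loss is minus the mean score on fakes. This is not the code of `./modules/WGANGP.py`. The `n_critic` loop, the tiny fully connected networks and the reuse of a single real batch for every critic update are simplifying assumptions.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x_hat)||_2 - 1)^2 over random real/fake interpolations."""
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)))   # one weight per sample
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(scores, x_hat, grad_outputs=torch.ones_like(scores),
                                   create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def wgan_gp_step(generator, critic, opt_g, opt_d, real, latent_dim, lambda_gp=10.0, n_critic=5):
    batch_size = real.size(0)
    # ---- Critic updates (the same real batch is reused here for simplicity)
    for _ in range(n_critic):
        fake = generator(torch.randn(batch_size, latent_dim)).detach()
        d_loss = (critic(fake).mean() - critic(real).mean()
                  + lambda_gp * gradient_penalty(critic, real, fake))
        opt_d.zero_grad()
        d_loss.backward()
        opt_d.step()
    # ---- Generator update: raise the critic score of the fakes
    g_loss = -critic(generator(torch.randn(batch_size, latent_dim))).mean()
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


latent_dim, lr, b1, b2 = 16, 0.0001, 0.5, 0.999
generator = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid(), nn.Unflatten(1, (1, 28, 28)))
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
opt_d = torch.optim.Adam(critic.parameters(), lr=lr, betas=(b1, b2))

real = torch.rand(32, 1, 28, 28)   # stand-in for one batch of sheep drawings
d_loss, g_loss = wgan_gp_step(generator, critic, opt_g, opt_d, real, latent_dim)
print(f"d_loss={d_loss:.3f}  g_loss={g_loss:.3f}")
```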

## Step 6 - Reload our best model
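Note : the checkpoint must be reloaded with the class that was used for training, `GAN_` (selected by `gan_name`), which is not necessarily `GAN`.

In [None]:
gan = GAN_.load_from_checkpoint(f'{run_dir}/models/bestModel.ckpt')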

In [None]:
nb_images = 96

z = torch.randn(nb_images, latent_dim)
print('z size        : ',z.size())

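# Put the model and the latent vectors on the same device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gan    = gan.to(device).eval()
z      = z.to(device)

# No gradients are needed to generate images
with torch.no_grad():
    fake_img = gan.generator(z)
print('fake_img      : ', fake_img.size())

nimg = fake_img.cpu().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='04-Sheeps-generated')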

In [None]:
fidle.end()

---
<img width="80px" src="../fidle/img/logo-paysage.svg"></img>