<img width="800px" src="../fidle/img/00-Fidle-header-01.svg"></img>

# <!-- TITLE --> [SHEEP2] - A WGAN-GP to Draw a Sheep
<!-- DESC --> Episode 2 : Draw me a sheep, revisited with a WGAN-GP
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->

## Objectives :
 - Build and train a WGAN-GP model with the Quick Draw dataset
 - Understand WGAN-GP
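
For reference, WGAN-GP (Gulrajani et al., 2017, "Improved Training of Wasserstein GANs") trains the critic $D$ and the generator $G$ with the losses below. Here $\lambda$ weights the gradient penalty and is set to 10 in the paper. Whether `modules.models.WGANGP` uses exactly these forms and this $\lambda$ is an assumption to check against its code.

$$
L_D = \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
\qquad
L_G = -\mathbb{E}_{z}\big[D(G(z))\big]
$$

The penalty is evaluated at $\hat{x} = \epsilon\, x + (1-\epsilon)\, G(z)$, with one $\epsilon \sim \mathcal{U}[0, 1]$ drawn per sample. It pushes the critic's gradient norm toward 1 on points between real and generated images. This is a soft version of the 1-Lipschitz constraint that the original WGAN enforced by weight clipping.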

The [Quick Draw dataset](https://quickdraw.withgoogle.com/data) contains about 50,000,000 drawings made by real people...  
We use a subset of 117,555 sheep drawings.  
To get the dataset: [https://github.com/googlecreativelab/quickdraw-dataset](https://github.com/googlecreativelab/quickdraw-dataset)  
Datasets as numpy bitmap files: [https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap](https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap)  
Sheep dataset: [https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy](https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy) (94.3 MB)


## What we're going to do :

 - Have a look at the dataset
 - Define a GAN model
 - Build the model
 - Train it
 - Analyze the results

## Acknowledgements :
Thanks to **François Chollet**, whose work this example is based on.  
See : [https://keras.io/examples/](https://keras.io/examples/)


## Step 1 - Init python stuff

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import sys

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import TensorBoard

from modules.models import WGANGP
from modules.callbacks import ImagesCallback

import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('SHEEP2')

## Step 2 - Parameters
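`scale`: fraction of the dataset to use. With scale=1, training takes 5-6 minutes on a V100 GPU... and more than 2 hours on a CPU!  
`latent_dim`: latent space dimension, e.g. 128  
`epochs`, `batch_size`: number of training epochs and batch size  
`n_critic`: number of critic (discriminator) updates per generator update  
`fit_verbosity`: verbosity during training: 0 = silent, 1 = progress bar, 2 = one line per epoch  
`num_img`: number of images to visualize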

**Notes:**
- The settings below (scale=0.01) allow the notebook to run on a laptop, but are not enough to get even a minimal result!  
- For a decent result, you need something like scale=1.  
- Since convergence is much better than in the previous episode, epochs can stay at 3 here :-)
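
To see what `scale` means in practice, here is a small sketch of the per-epoch workload. It takes the 117,555 figure quoted above (the notebook itself uses `len(x_data)`), together with the default `batch_size=64` and `n_critic=2`. The function name `epoch_budget` is hypothetical. The critic-update count assumes that `WGANGP` runs `n_critic` critic updates per training step, as the Keras WGAN-GP example does.

```python
import math


def epoch_budget(scale, n_total=117_555, batch_size=64, n_critic=2):
    """Images, steps and critic updates per epoch for a given scale (sketch)."""
    n = int(scale * n_total)              # same rule as the loading cell
    steps = math.ceil(n / batch_size)     # Keras keeps the last, partial batch
    # Assumption: WGANGP performs n_critic critic updates per training step
    return n, steps, steps * n_critic


for scale in (0.01, 1.0):
    print(scale, epoch_budget(scale))
# 0.01 (1175, 19, 38)
# 1.0 (117555, 1837, 3674)
```

With scale=0.01 the critic sees only about 1,200 drawings, which is why the default settings cannot produce convincing sheep.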

In [None]:
latent_dim = 128

scale = 0.01
epochs = 3
n_critic = 2
batch_size = 64
num_img = 12
fit_verbosity = 1

Override parameters (batch mode) - Just forget this cell

In [None]:
fidle.override('scale', 'latent_dim', 'epochs', 'batch_size', 'num_img', 'fit_verbosity')

## Step 3 - Load dataset and have a look 
Load the sheep drawings as numpy bitmaps...

In [None]:
# Load dataset
x_data = np.load(datasets_dir+'/QuickDraw/origine/sheep.npy')
print('Original dataset shape : ',x_data.shape)

# Rescale
n=int(scale*len(x_data))
x_data = x_data[:n]
print('Rescaled dataset shape : ',x_data.shape)

# Normalize, reshape and shuffle
x_data = x_data/255
x_data = x_data.reshape(-1,28,28,1)
np.random.shuffle(x_data)
print('Final dataset shape : ',x_data.shape)
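
The preprocessing above turns each 784-value row into a 28x28x1 image with values in [0, 1]. This is the same range as the generator's final sigmoid, so the critic compares real and generated images on the same scale. The sketch below replays the steps on random stand-in data, not on the real `sheep.npy`. The `float32` cast is an optional addition that halves memory compared with the `float64` produced by `/255`.

```python
import numpy as np

# Synthetic stand-in for sheep.npy (random values, illustration only):
# each row is one 28x28 drawing flattened to 784 uint8 grey levels in 0..255
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(10, 784), dtype=np.uint8)

# Same steps as the loading cell: scale to [0, 1], then add a channel axis.
# The float32 cast is optional and not in the original cell.
x = (raw / 255).astype(np.float32)
x = x.reshape(-1, 28, 28, 1)

print(x.shape, x.dtype, float(x.min()), float(x.max()))
```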


...and have a look:  
Note: these sheep were drawn... by real humans!

In [None]:
fidle.scrawler.images( x_data.reshape(-1,28,28), indices=range(72), columns=12, x_size=1, y_size=1, 
 y_padding=0,spines_alpha=0, save_as='01-Sheeps')

## Step 4 - Create a discriminator

In [None]:
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(inputs)
x = layers.LeakyReLU(alpha=0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(alpha=0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(alpha=0.2)(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.2)(x)
c = layers.Dense(1)(x)

discriminator = keras.Model(inputs, c, name="discriminator")
discriminator.summary()
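
In WGAN terminology this network is usually called the critic. Its last `Dense(1)` has no activation, so it outputs an unbounded score rather than a probability. As a cross-check of `discriminator.summary()`, this sketch recomputes the feature-map sizes and the parameter count by hand. It relies on the Keras rule that `padding="same"` gives an output size of `ceil(input / stride)`. `LeakyReLU`, `Flatten` and `Dropout` have no weights. The helper names are hypothetical.

```python
import math


def same_conv_output(size, stride):
    # With padding="same", Keras output size is ceil(input / stride)
    return math.ceil(size / stride)


def conv_params(kernel, in_ch, out_ch):
    # kernel*kernel*in_ch weights per filter, plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch


size, channels, total = 28, 1, 0
sizes = []
for filters in (64, 128, 128):          # the three Conv2D(kernel_size=4, strides=2) layers
    total += conv_params(4, channels, filters)
    size, channels = same_conv_output(size, 2), filters
    sizes.append(size)

flat = size * size * channels           # Flatten
total += flat * 1 + 1                   # Dense(1): raw critic score, no activation
print(sizes, flat, total)               # [14, 7, 4] 2048 396609
```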

## Step 5 - Create a generator

In [None]:
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64)(inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.UpSampling2D()(x)
x = layers.Conv2D(128, kernel_size=3, strides=1, padding='same', activation='relu')(x)
x = layers.UpSampling2D()(x)
x = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
outputs = layers.Conv2D(1, kernel_size=5, strides=1, padding="same", activation='sigmoid')(x)

generator = keras.Model(inputs, outputs, name="generator")
generator.summary()
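
The generator mirrors the critic. A dense layer produces a 7x7x64 tensor, each `UpSampling2D` doubles height and width (7 to 14 to 28), and the stride-1 `"same"` convolutions keep the spatial size. The final 1-channel sigmoid convolution yields 28x28x1 images in [0, 1]. The sketch below recomputes these shapes and the parameter counts, for comparison with `generator.summary()`. `conv_params` is a hypothetical helper.

```python
def conv_params(kernel, in_ch, out_ch):
    # kernel*kernel*in_ch weights per filter, plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch


latent_dim = 128
dense = latent_dim * (7 * 7 * 64) + 7 * 7 * 64     # Dense(7*7*64), weights + biases
size, channels = 7, 64                              # Reshape((7, 7, 64))
params = [dense]
for filters, kernel in ((128, 3), (256, 3)):
    size *= 2                                       # UpSampling2D doubles H and W
    params.append(conv_params(kernel, channels, filters))  # stride 1, "same": size kept
    channels = filters
params.append(conv_params(5, channels, 1))          # final 1-channel sigmoid Conv2D
print((size, size, 1), params, sum(params))
# (28, 28, 1) [404544, 73856, 295168, 6401] 779969
```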

## Step 6 - Build, compile and train our WGAN-GP
Duration: about 5 minutes on a V100 with scale=0.5, epochs=10, n_critic=2  
First, clean up previously saved images:

In [None]:
!rm $run_dir/images/*.jpg >/dev/null 2>&1 

Build our model :

In [None]:
gan = WGANGP(discriminator=discriminator, generator=generator, latent_dim=latent_dim, n_critic=n_critic)
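
The `WGANGP` class comes from the local `modules/models` file, which is not shown here. The following is a sketch of what its critic step typically computes, assuming it follows the standard WGAN-GP recipe (as in the Keras WGAN-GP example). It is not taken from this module's code. The helper names `gradient_penalty`, `critic_loss` and `generator_loss` are hypothetical, and the default `gp_weight=10.0` is the paper's value.

```python
import tensorflow as tf


def gradient_penalty(critic, real_images, fake_images):
    """Mean of (||grad D(x_hat)||_2 - 1)^2 over a batch (sketch).

    Assumes float32 image tensors of shape (batch, 28, 28, 1) and a critic
    that maps such a batch to one score per image.
    """
    batch_size = tf.shape(real_images)[0]
    # One epsilon ~ U[0, 1] per sample, broadcast over height, width, channels
    epsilon = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    x_hat = epsilon * real_images + (1.0 - epsilon) * fake_images

    with tf.GradientTape() as tape:
        tape.watch(x_hat)  # x_hat is a plain tensor, so it must be watched
        scores = critic(x_hat)
    # Each score depends only on its own image, so the gradient of the
    # (implicitly summed) scores gives the per-sample gradients
    grads = tape.gradient(scores, x_hat)
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norms - 1.0))


def critic_loss(real_scores, fake_scores, gp, gp_weight=10.0):
    # E[D(G(z))] - E[D(x)] + lambda * penalty
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores) + gp_weight * gp


def generator_loss(fake_scores):
    # The generator tries to raise the critic's score on its samples
    return -tf.reduce_mean(fake_scores)
```

A custom `train_step` would compute `critic_loss` `n_critic` times per batch, each time under an outer `tf.GradientTape` over the critic's weights. The inner tape in `gradient_penalty` nests inside it, because the penalty itself must be differentiated.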

In [None]:
gan.compile(
# discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0001),
# generator_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
 discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9),
 generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
)

Add a callback to save images, then train our WGAN-GP model and save it:

In [None]:
imagesCallback = ImagesCallback(num_img=num_img, latent_dim=latent_dim, run_dir=f'{run_dir}/images')

history = gan.fit( x_data, 
 epochs=epochs, 
 batch_size=batch_size, 
 callbacks=[imagesCallback], 
 verbose=fit_verbosity )

gan.save(f'{run_dir}/models/model.h5')

## Step 7 - History

In [None]:
fidle.scrawler.history(history, plot={'loss':['d_loss','g_loss']}, save_as='01-history')

In [None]:
images=[]
for epoch in range(0,epochs,1):
    for i in range(num_img):
        filename = f'{run_dir}/images/image-{epoch:03d}-{i:02d}.jpg'
        image = io.imread(filename)
        images.append(image)

fidle.scrawler.images(images, None, indices='all', columns=num_img, x_size=1,y_size=1, interpolation=None, y_padding=0, spines_alpha=0, save_as='04-learning')
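
The loop above expects one file per image and per epoch, named `image-EEE-II.jpg`. The resulting grid has one row per epoch and `num_img` columns. The naming is inferred from that loop, not from the `ImagesCallback` source. If a run was interrupted, `io.imread` fails on the first missing file. This sketch uses two hypothetical helpers that list the expected paths and report the ones absent from disk.

```python
import os


def expected_image_files(run_dir, epochs, num_img):
    # Hypothetical helper: the paths read by the loop above, in the same
    # order (all images of epoch 0, then epoch 1, ...)
    return [f'{run_dir}/images/image-{epoch:03d}-{i:02d}.jpg'
            for epoch in range(epochs) for i in range(num_img)]


def missing_image_files(run_dir, epochs, num_img):
    # Expected files that are not on disk
    return [f for f in expected_image_files(run_dir, epochs, num_img)
            if not os.path.exists(f)]
```

For example, `missing_image_files(run_dir, epochs, num_img)` returns an empty list when every epoch produced its images.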

## Step 8 - Generation
Reload our saved model :

In [None]:
gan.reload(f'{run_dir}/models/model.h5')

Generate some images from the latent space:

In [None]:
nb_images = 12*15

z = np.random.normal(size=(nb_images,latent_dim))
images = gan.predict(z, verbose=0)
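
`np.random.normal` draws a new batch of latent vectors on every run, so the generated sheep change each time. To regenerate the same sheep, one option is a seeded NumPy `Generator`, as in the sketch below. It assumes the model was trained with a standard normal prior, which is what `np.random.normal` above implies. The function name `sample_latent` is hypothetical.

```python
import numpy as np


def sample_latent(nb_images, latent_dim, seed=None):
    # Standard normal prior, same distribution as np.random.normal above,
    # but drawn from a seeded Generator so results can be reproduced
    rng = np.random.default_rng(seed)
    return rng.normal(size=(nb_images, latent_dim)).astype(np.float32)
```

The output can replace `z`, e.g. `images = gan.predict(sample_latent(nb_images, latent_dim, seed=42), verbose=0)`.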


Plot it :

In [None]:
fidle.scrawler.images(images, None, indices='all', columns=num_img, x_size=1,y_size=1, interpolation=None, y_padding=0, spines_alpha=0, save_as='05-sheep')

In [None]:
fidle.end()

---
<img width="80px" src="../fidle/img/00-Fidle-logo-01.svg"></img>