Skip to content
Snippets Groups Projects
Commit 7d140cc6 authored by Jean-Luc Parouty's avatar Jean-Luc Parouty
Browse files

Update 01-DCGAN-PL.ipynb to add log_dir path

parent b4127a07
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
<img width="800px" src="../fidle/img/00-Fidle-header-01.svg"></img>
# <!-- TITLE --> [SHEEP3] - A DCGAN to Draw a Sheep, with Pytorch Lightning
<!-- DESC --> Episode 1 : Draw me a sheep, revisited with a DCGAN, writing in Pytorch Lightning
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->
## Objectives :
- Build and train a DCGAN model with the Quick Draw dataset
- Understanding DCGAN
The [Quick draw dataset](https://quickdraw.withgoogle.com/data) contains about 50.000.000 drawings, made by real people...
We are using a subset of 117.555 of Sheep drawings
To get the dataset : [https://github.com/googlecreativelab/quickdraw-dataset](https://github.com/googlecreativelab/quickdraw-dataset)
Datasets in numpy bitmap file : [https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap](https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap)
Sheep dataset : [https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy](https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy) (94.3 Mo)
## What we're going to do :
- Have a look to the dataset
- Defining a GAN model
- Build the model
- Train it
- Have a look of the results
%% Cell type:markdown id: tags:
## Step 1 - Init and parameters
#### Python init
%% Cell type:code id: tags:
``` python
import os
import sys
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from tqdm import tqdm
from torch.utils.data import DataLoader
import fidle
from modules.SmartProgressBar import SmartProgressBar
from modules.QuickDrawDataModule import QuickDrawDataModule
from modules.GAN import GAN
from modules.WGANGP import WGANGP
from modules.Generators import *
from modules.Discriminators import *
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('SHEEP3')
```
%% Cell type:markdown id: tags:
#### Few parameters
%% Cell type:code id: tags:
``` python
latent_dim = 128
gan_class = 'WGANGP'
generator_class = 'Generator_2'
discriminator_class = 'Discriminator_3'
scale = 0.001
epochs = 3
lr = 0.0001
b1 = 0.5
b2 = 0.999
batch_size = 32
num_img = 48
fit_verbosity = 2
dataset_file = datasets_dir+'/QuickDraw/origine/sheep.npy'
data_shape = (28,28,1)
```
%% Cell type:markdown id: tags:
#### Cleaning
%% Cell type:code id: tags:
``` python
# You can comment these lines to keep each run...
shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)
shutil.rmtree(f'{run_dir}/models', ignore_errors=True)
shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)
```
%% Cell type:markdown id: tags:
## Step 2 - Get some nice data
%% Cell type:markdown id: tags:
#### Get a Nice DataModule
Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py)
This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)
%% Cell type:code id: tags:
``` python
dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=8)
dm.setup()
```
%% Cell type:markdown id: tags:
#### Have a look
%% Cell type:code id: tags:
``` python
dl = dm.train_dataloader()
batch_data = next(iter(dl))
fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
%% Cell type:markdown id: tags:
## Step 3 - Get a nice GAN model
Our Generators are defined in [./modules/Generators.py](./modules/Generators.py)
Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py)
Our GAN is defined in [./modules/GAN.py](./modules/GAN.py)
#### Class loader
%% Cell type:code id: tags:
``` python
def get_class(class_name):
module=sys.modules['__main__']
class_ = getattr(module, class_name)
return class_
def get_instance(class_name, **args):
module=sys.modules['__main__']
class_ = getattr(module, class_name)
instance_ = class_(**args)
return instance_
```
%% Cell type:markdown id: tags:
#### Basic test - Just to be sure it (could) works... ;-)
%% Cell type:code id: tags:
``` python
# ---- A little piece of black magic to instantiate a class from its name
#
def get_classByName(class_name, **args):
module=sys.modules['__main__']
class_ = getattr(module, class_name)
instance_ = class_(**args)
return instance_
# ----Get it, and play with them
#
print('\nInstantiation :\n')
Generator_ = get_class(generator_class)
Discriminator_ = get_class(discriminator_class)
generator = Generator_( latent_dim=latent_dim, data_shape=data_shape)
discriminator = Discriminator_( latent_dim=latent_dim, data_shape=data_shape)
print('\nFew tests :\n')
z = torch.randn(batch_size, latent_dim)
print('z size : ',z.size())
fake_img = generator.forward(z)
print('fake_img : ', fake_img.size())
p = discriminator.forward(fake_img)
print('pred fake : ', p.size())
print('batch_data : ',batch_data.size())
p = discriminator.forward(batch_data)
print('pred real : ', p.size())
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
%% Cell type:code id: tags:
``` python
print(fake_img.size())
print(batch_data.size())
e = torch.distributions.uniform.Uniform(0, 1).sample([32,1])
e = e[:None,None,None]
i = fake_img * e + (1-e)*batch_data
nimg = i.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
%% Cell type:markdown id: tags:
#### GAN model
To simplify our code, the GAN class is defined separately in the module [./modules/GAN.py](./modules/GAN.py)
Passing the classe names for generator/discriminator by parameter allows to stay modular and to use the PL checkpoints.
%% Cell type:code id: tags:
``` python
GAN_ = get_class(gan_class)
gan = GAN_( data_shape = data_shape,
lr = lr,
b1 = b1,
b2 = b2,
batch_size = batch_size,
latent_dim = latent_dim,
generator_class = generator_class,
discriminator_class = discriminator_class)
```
%% Cell type:markdown id: tags:
## Step 5 - Train it !
#### Instantiate Callbacks, Logger & co.
More about :
- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)
- [modelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)
%% Cell type:code id: tags:
``` python
# ---- for tensorboard logs
#
logger = TensorBoardLogger( save_dir = f'{run_dir}',
name = 'tb_logs' )
log_dir = os.path.abspath(f'{run_dir}/tb_logs')
print('To access the logs with tensorboard, use this command line :')
print(f'tensorboard --logdir {log_dir}')
# ---- To save checkpoints
#
callback_checkpoints = ModelCheckpoint( dirpath = f'{run_dir}/models',
filename = 'bestModel',
save_top_k = 1,
save_last = True,
every_n_epochs = 1,
monitor = "g_loss")
# ---- To have a nive progress bar
#
callback_progressBar = SmartProgressBar(verbosity=2) # Usable evertywhere
# progress_bar = TQDMProgressBar(refresh_rate=1) # Usable in real jupyter lab (bug in vscode)
```
%% Cell type:markdown id: tags:
#### Train it
%% Cell type:code id: tags:
``` python
trainer = Trainer(
accelerator = "auto",
max_epochs = epochs,
callbacks = [callback_progressBar, callback_checkpoints],
log_every_n_steps = batch_size,
logger = logger
)
trainer.fit(gan, dm)
```
%% Cell type:markdown id: tags:
## Step 6 - Reload our best model
Note :
%% Cell type:code id: tags:
``` python
gan = WGANGP.load_from_checkpoint('./run/SHEEP3/models/bestModel.ckpt')
```
%% Cell type:code id: tags:
``` python
nb_images = 96
z = torch.randn(nb_images, latent_dim)
print('z size : ',z.size())
fake_img = gan.generator.forward(z)
print('fake_img : ', fake_img.size())
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
%% Cell type:code id: tags:
``` python
fidle.end()
```
%% Cell type:markdown id: tags:
---
<img width="80px" src="../fidle/img/00-Fidle-logo-01.svg"></img>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment