Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Fidle
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Container Registry
Model registry
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Talks
Fidle
Commits
21789e50
Commit
21789e50
authored
2 years ago
by
Jean-Luc Parouty
Browse files
Options
Downloads
Patches
Plain Diff
Add batch normalisation
parent
379474fd
No related branches found
Branches containing commit
Tags
1.0.1
Tags containing commit
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
DCGAN-PyTorch/01-DCGAN-PL.ipynb
+33
-4
33 additions, 4 deletions
DCGAN-PyTorch/01-DCGAN-PL.ipynb
DCGAN-PyTorch/modules/GAN.py
+8
-2
8 additions, 2 deletions
DCGAN-PyTorch/modules/GAN.py
DCGAN-PyTorch/modules/Generators.py
+9
-3
9 additions, 3 deletions
DCGAN-PyTorch/modules/Generators.py
with
50 additions
and
9 deletions
DCGAN-PyTorch/01-DCGAN-PL.ipynb
+
33
−
4
View file @
21789e50
...
@@ -82,8 +82,8 @@
...
@@ -82,8 +82,8 @@
"generator_class = 'Generator_2'\n",
"generator_class = 'Generator_2'\n",
"discriminator_class = 'Discriminator_1' \n",
"discriminator_class = 'Discriminator_1' \n",
" \n",
" \n",
"scale = .1\n",
"scale = .
0
1\n",
"epochs =
10
\n",
"epochs =
5
\n",
"batch_size = 32\n",
"batch_size = 32\n",
"num_img = 36\n",
"num_img = 36\n",
"fit_verbosity = 2\n",
"fit_verbosity = 2\n",
...
@@ -211,7 +211,8 @@
...
@@ -211,7 +211,8 @@
" batch_size = batch_size, \n",
" batch_size = batch_size, \n",
" latent_dim = latent_dim, \n",
" latent_dim = latent_dim, \n",
" generator_class = generator_class, \n",
" generator_class = generator_class, \n",
" discriminator_class = discriminator_class)"
" discriminator_class = discriminator_class,\n",
" lr=0.0001)"
]
]
},
},
{
{
...
@@ -295,9 +296,37 @@
...
@@ -295,9 +296,37 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"
#
gan = GAN.load_from_checkpoint('./run/SHEEP3/
lightning_logs/version_3/checkpoints/epoch=4-step=1980
.ckpt')"
"gan = GAN.load_from_checkpoint('./run/SHEEP3/
models/last
.ckpt')"
]
]
},
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nb_images = 32\n",
"\n",
"# z = np.random.normal(size=(nb_images,latent_dim))\n",
"\n",
"z = torch.randn(nb_images, latent_dim)\n",
"print('z size : ',z.size())\n",
"\n",
"fake_img = gan.generator.forward(z)\n",
"print('fake_img : ', fake_img.size())\n",
"\n",
"nimg = fake_img.detach().numpy()\n",
"fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1, \n",
" y_padding=0,spines_alpha=0, save_as='01-Sheeps')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": null,
"execution_count": null,
...
...
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# GAN using PyTorch Lightning
# GAN using PyTorch Lightning
See :
See :
-
https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html
-
https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html
-
https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
-
https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
Note : Need
Note : Need
```
pip install ipywidgets lightning tqdm```
```
pip install ipywidgets lightning tqdm```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 1 - Init and parameters
## Step 1 - Init and parameters
#### Python init
#### Python init
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
import os
import os
import sys
import sys
import numpy as np
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F
import torchvision
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms as transforms
from lightning import LightningDataModule, LightningModule, Trainer
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from tqdm import tqdm
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
import fidle
import fidle
from modules.SmartProgressBar import SmartProgressBar
from modules.SmartProgressBar import SmartProgressBar
from modules.QuickDrawDataModule import QuickDrawDataModule
from modules.QuickDrawDataModule import QuickDrawDataModule
from modules.GAN import GAN
from modules.GAN import GAN
from modules.Generators import
*
from modules.Generators import
*
from modules.Discriminators import
*
from modules.Discriminators import
*
# Init Fidle environment
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('SHEEP3')
run_id, run_dir, datasets_dir = fidle.init('SHEEP3')
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
#### Few parameters
#### Few parameters
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
latent_dim = 128
latent_dim = 128
generator_class = 'Generator_2'
generator_class = 'Generator_2'
discriminator_class = 'Discriminator_1'
discriminator_class = 'Discriminator_1'
scale = .1
scale = .
0
1
epochs =
10
epochs =
5
batch_size = 32
batch_size = 32
num_img = 36
num_img = 36
fit_verbosity = 2
fit_verbosity = 2
dataset_file = datasets_dir+'/QuickDraw/origine/sheep.npy'
dataset_file = datasets_dir+'/QuickDraw/origine/sheep.npy'
data_shape = (28,28,1)
data_shape = (28,28,1)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 2 - Get some nice data
## Step 2 - Get some nice data
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
#### Get a Nice DataModule
#### Get a Nice DataModule
Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py)
Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py)
This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)
This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=8)
dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=8)
dm.setup()
dm.setup()
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
#### Have a look
#### Have a look
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
dl = dm.train_dataloader()
dl = dm.train_dataloader()
batch_data = next(iter(dl))
batch_data = next(iter(dl))
fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 3 - Get a nice GAN model
## Step 3 - Get a nice GAN model
Our Generators are defined in [./modules/Generators.py](./modules/Generators.py)
Our Generators are defined in [./modules/Generators.py](./modules/Generators.py)
Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py)
Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py)
Our GAN is defined in [./modules/GAN.py](./modules/GAN.py)
Our GAN is defined in [./modules/GAN.py](./modules/GAN.py)
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
#### Basic test - Just to be sure it (could) works... ;-)
#### Basic test - Just to be sure it (could) works... ;-)
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
print('
\n
Instantiation :
\n
')
print('
\n
Instantiation :
\n
')
generator = Generator_2(latent_dim=latent_dim, data_shape=data_shape)
generator = Generator_2(latent_dim=latent_dim, data_shape=data_shape)
discriminator = Discriminator_1(latent_dim=latent_dim, data_shape=data_shape)
discriminator = Discriminator_1(latent_dim=latent_dim, data_shape=data_shape)
print('
\n
Few tests :
\n
')
print('
\n
Few tests :
\n
')
z = torch.randn(batch_size, latent_dim)
z = torch.randn(batch_size, latent_dim)
print('z size : ',z.size())
print('z size : ',z.size())
fake_img = generator.forward(z)
fake_img = generator.forward(z)
print('fake_img : ', fake_img.size())
print('fake_img : ', fake_img.size())
p = discriminator.forward(fake_img)
p = discriminator.forward(fake_img)
print('pred fake : ', p.size())
print('pred fake : ', p.size())
print('batch_data : ',batch_data.size())
print('batch_data : ',batch_data.size())
p = discriminator.forward(batch_data)
p = discriminator.forward(batch_data)
print('pred real : ', p.size())
print('pred real : ', p.size())
nimg = fake_img.detach().numpy()
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
#### GAN model
#### GAN model
To simplify our code, the GAN class is defined separately in the module [./modules/GAN.py](./modules/GAN.py)
To simplify our code, the GAN class is defined separately in the module [./modules/GAN.py](./modules/GAN.py)
Passing the classe names for generator/discriminator by parameter allows to stay modular and to use the PL checkpoints.
Passing the classe names for generator/discriminator by parameter allows to stay modular and to use the PL checkpoints.
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
gan = GAN( data_shape = data_shape,
gan = GAN( data_shape = data_shape,
batch_size = batch_size,
batch_size = batch_size,
latent_dim = latent_dim,
latent_dim = latent_dim,
generator_class = generator_class,
generator_class = generator_class,
discriminator_class = discriminator_class)
discriminator_class = discriminator_class,
lr=0.0001)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 5 - Train it !
## Step 5 - Train it !
#### Instantiate Callbacks, Logger & co.
#### Instantiate Callbacks, Logger & co.
More about :
More about :
- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)
- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)
- [modelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)
- [modelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
# ---- for tensorboard logs
# ---- for tensorboard logs
#
#
logger = TensorBoardLogger( save_dir = f'{run_dir}',
logger = TensorBoardLogger( save_dir = f'{run_dir}',
name = 'tb_logs' )
name = 'tb_logs' )
# ---- To save checkpoints
# ---- To save checkpoints
#
#
callback_checkpoints = ModelCheckpoint( dirpath = f'{run_dir}/models',
callback_checkpoints = ModelCheckpoint( dirpath = f'{run_dir}/models',
filename = 'bestModel',
filename = 'bestModel',
save_top_k = 1,
save_top_k = 1,
save_last = True,
save_last = True,
every_n_epochs = 1,
every_n_epochs = 1,
monitor = "g_loss")
monitor = "g_loss")
# ---- To have a nive progress bar
# ---- To have a nive progress bar
#
#
callback_progressBar = SmartProgressBar(verbosity=2) # Usable evertywhere
callback_progressBar = SmartProgressBar(verbosity=2) # Usable evertywhere
# progress_bar = TQDMProgressBar(refresh_rate=1) # Usable in real jupyter lab (bug in vscode)
# progress_bar = TQDMProgressBar(refresh_rate=1) # Usable in real jupyter lab (bug in vscode)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
#### Train it
#### Train it
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
trainer = Trainer(
trainer = Trainer(
accelerator = "auto",
accelerator = "auto",
# devices = 1 if torch.cuda.is_available() else None, # limiting got iPython runs
# devices = 1 if torch.cuda.is_available() else None, # limiting got iPython runs
max_epochs = epochs,
max_epochs = epochs,
callbacks = [callback_progressBar, callback_checkpoints],
callbacks = [callback_progressBar, callback_checkpoints],
log_every_n_steps = batch_size,
log_every_n_steps = batch_size,
logger = logger
logger = logger
)
)
trainer.fit(gan, dm)
trainer.fit(gan, dm)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 6 - Reload a checkpoint
## Step 6 - Reload a checkpoint
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
# gan = GAN.load_from_checkpoint('./run/SHEEP3/lightning_logs/version_3/checkpoints/epoch=4-step=1980.ckpt')
gan = GAN.load_from_checkpoint('./run/SHEEP3/models/last.ckpt')
```
%% Cell type:code id: tags:
```
python
nb_images = 32
# z = np.random.normal(size=(nb_images,latent_dim))
z = torch.randn(nb_images, latent_dim)
print('z size : ',z.size())
fake_img = gan.generator.forward(z)
print('fake_img : ', fake_img.size())
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
```
%% Cell type:code id: tags:
```
python
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
```
```
...
...
This diff is collapsed.
Click to expand it.
DCGAN-PyTorch/modules/GAN.py
+
8
−
2
View file @
21789e50
...
@@ -116,6 +116,8 @@ class GAN(LightningModule):
...
@@ -116,6 +116,8 @@ class GAN(LightningModule):
# These images are reals
# These images are reals
real_labels
=
torch
.
ones
(
batch_size
,
1
)
real_labels
=
torch
.
ones
(
batch_size
,
1
)
# Add random noise to the labels
# real_labels += 0.05 * torch.rand(batch_size,1)
real_labels
=
real_labels
.
type_as
(
imgs
)
real_labels
=
real_labels
.
type_as
(
imgs
)
pred_labels
=
self
.
discriminator
.
forward
(
imgs
)
pred_labels
=
self
.
discriminator
.
forward
(
imgs
)
...
@@ -124,6 +126,8 @@ class GAN(LightningModule):
...
@@ -124,6 +126,8 @@ class GAN(LightningModule):
# These images are fake
# These images are fake
fake_imgs
=
self
.
generator
.
forward
(
z
)
fake_imgs
=
self
.
generator
.
forward
(
z
)
fake_labels
=
torch
.
zeros
(
batch_size
,
1
)
fake_labels
=
torch
.
zeros
(
batch_size
,
1
)
# Add random noise to the labels
# fake_labels += 0.05 * torch.rand(batch_size,1)
fake_labels
=
fake_labels
.
type_as
(
imgs
)
fake_labels
=
fake_labels
.
type_as
(
imgs
)
fake_loss
=
self
.
adversarial_loss
(
self
.
discriminator
(
fake_imgs
.
detach
()),
fake_labels
)
fake_loss
=
self
.
adversarial_loss
(
self
.
discriminator
(
fake_imgs
.
detach
()),
fake_labels
)
...
@@ -143,8 +147,10 @@ class GAN(LightningModule):
...
@@ -143,8 +147,10 @@ class GAN(LightningModule):
# With a GAN, we need 2 separate optimizer.
# With a GAN, we need 2 separate optimizer.
# opt_g to optimize the generator #0
# opt_g to optimize the generator #0
# opt_d to optimize the discriminator #1
# opt_d to optimize the discriminator #1
opt_g
=
torch
.
optim
.
Adam
(
self
.
generator
.
parameters
(),
lr
=
lr
,
betas
=
(
b1
,
b2
))
# opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
opt_d
=
torch
.
optim
.
Adam
(
self
.
discriminator
.
parameters
(),
lr
=
lr
,
betas
=
(
b1
,
b2
))
# opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2),)
opt_g
=
torch
.
optim
.
Adam
(
self
.
generator
.
parameters
(),
lr
=
lr
)
opt_d
=
torch
.
optim
.
Adam
(
self
.
discriminator
.
parameters
(),
lr
=
lr
)
return
[
opt_g
,
opt_d
],
[]
return
[
opt_g
,
opt_d
],
[]
...
...
This diff is collapsed.
Click to expand it.
DCGAN-PyTorch/modules/Generators.py
+
9
−
3
View file @
21789e50
...
@@ -67,12 +67,16 @@ class Generator_2(nn.Module):
...
@@ -67,12 +67,16 @@ class Generator_2(nn.Module):
nn
.
Linear
(
latent_dim
,
7
*
7
*
64
),
nn
.
Linear
(
latent_dim
,
7
*
7
*
64
),
nn
.
Unflatten
(
1
,
(
64
,
7
,
7
)),
nn
.
Unflatten
(
1
,
(
64
,
7
,
7
)),
nn
.
UpsamplingBilinear2d
(
scale_factor
=
2
),
nn
.
UpsamplingNearest2d
(
scale_factor
=
2
),
# nn.UpsamplingBilinear2d( scale_factor=2 ),
nn
.
Conv2d
(
64
,
128
,
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
)
),
nn
.
Conv2d
(
64
,
128
,
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
)
),
nn
.
BatchNorm2d
(
128
),
nn
.
ReLU
(),
nn
.
ReLU
(),
nn
.
UpsamplingBilinear2d
(
scale_factor
=
2
),
nn
.
UpsamplingNearest2d
(
scale_factor
=
2
),
# nn.UpsamplingBilinear2d( scale_factor=2 ),
nn
.
Conv2d
(
128
,
256
,
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
)),
nn
.
Conv2d
(
128
,
256
,
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
)),
nn
.
BatchNorm2d
(
256
),
nn
.
ReLU
(),
nn
.
ReLU
(),
nn
.
Conv2d
(
256
,
1
,
(
5
,
5
),
stride
=
(
1
,
1
),
padding
=
(
2
,
2
)),
nn
.
Conv2d
(
256
,
1
,
(
5
,
5
),
stride
=
(
1
,
1
),
padding
=
(
2
,
2
)),
...
@@ -82,6 +86,8 @@ class Generator_2(nn.Module):
...
@@ -82,6 +86,8 @@ class Generator_2(nn.Module):
def
forward
(
self
,
z
):
def
forward
(
self
,
z
):
img
=
self
.
model
(
z
)
img
=
self
.
model
(
z
)
img
=
img
.
view
(
img
.
size
(
0
),
*
self
.
img_shape
)
img
=
img
.
view
(
img
.
size
(
0
),
*
self
.
img_shape
)
# batch_size x 1 x W x H => batch_size x W x H x 1
return
img
return
img
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment