Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Fidle
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Container Registry
Model registry
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Talks
Fidle
Commits
7d140cc6
Commit
7d140cc6
authored
2 years ago
by
Jean-Luc Parouty
Browse files
Options
Downloads
Patches
Plain Diff
Update 01-DCGAN-PL.ipynb to add log_dir path
parent
b4127a07
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
DCGAN-PyTorch/01-DCGAN-PL.ipynb
+4
-0
4 additions, 0 deletions
DCGAN-PyTorch/01-DCGAN-PL.ipynb
with
4 additions
and
0 deletions
DCGAN-PyTorch/01-DCGAN-PL.ipynb
+
4
−
0
View file @
7d140cc6
...
...
@@ -338,6 +338,10 @@
"logger = TensorBoardLogger( save_dir = f'{run_dir}',\n",
" name = 'tb_logs' )\n",
"\n",
"log_dir = os.path.abspath(f'{run_dir}/tb_logs')\n",
"print('To access the logs with tensorboard, use this command line :')\n",
"print(f'tensorboard --logdir {log_dir}')\n",
"\n",
"# ---- To save checkpoints\n",
"#\n",
"callback_checkpoints = ModelCheckpoint( dirpath = f'{run_dir}/models', \n",
...
...
%% Cell type:markdown id: tags:
<img
width=
"800px"
src=
"../fidle/img/00-Fidle-header-01.svg"
></img>
# <!-- TITLE --> [SHEEP3] - A DCGAN to Draw a Sheep, with Pytorch Lightning
<!-- DESC -->
Episode 1 : Draw me a sheep, revisited with a DCGAN, writing in Pytorch Lightning
<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->
## Objectives :
-
Build and train a DCGAN model with the Quick Draw dataset
-
Understanding DCGAN
The
[
Quick draw dataset
](
https://quickdraw.withgoogle.com/data
)
contains about 50.000.000 drawings, made by real people...
We are using a subset of 117.555 of Sheep drawings
To get the dataset :
[
https://github.com/googlecreativelab/quickdraw-dataset
](
https://github.com/googlecreativelab/quickdraw-dataset
)
Datasets in numpy bitmap file :
[
https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap
](
https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap
)
Sheep dataset :
[
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy
](
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy
)
(
94.3
Mo)
## What we're going to do :
-
Have a look to the dataset
-
Defining a GAN model
-
Build the model
-
Train it
-
Have a look of the results
%% Cell type:markdown id: tags:
## Step 1 - Init and parameters
#### Python init
%% Cell type:code id: tags:
```
python
import
os
import
sys
import
shutil
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torchvision
import
torchvision.transforms
as
transforms
from
lightning
import
LightningDataModule
,
LightningModule
,
Trainer
from
lightning.pytorch.callbacks.progress.tqdm_progress
import
TQDMProgressBar
from
lightning.pytorch.callbacks.progress.base
import
ProgressBarBase
from
lightning.pytorch.callbacks
import
ModelCheckpoint
from
lightning.pytorch.loggers.tensorboard
import
TensorBoardLogger
from
tqdm
import
tqdm
from
torch.utils.data
import
DataLoader
import
fidle
from
modules.SmartProgressBar
import
SmartProgressBar
from
modules.QuickDrawDataModule
import
QuickDrawDataModule
from
modules.GAN
import
GAN
from
modules.WGANGP
import
WGANGP
from
modules.Generators
import
*
from
modules.Discriminators
import
*
# Init Fidle environment
run_id
,
run_dir
,
datasets_dir
=
fidle
.
init
(
'
SHEEP3
'
)
```
%% Cell type:markdown id: tags:
#### Few parameters
%% Cell type:code id: tags:
```
python
latent_dim
=
128
gan_class
=
'
WGANGP
'
generator_class
=
'
Generator_2
'
discriminator_class
=
'
Discriminator_3
'
scale
=
0.001
epochs
=
3
lr
=
0.0001
b1
=
0.5
b2
=
0.999
batch_size
=
32
num_img
=
48
fit_verbosity
=
2
dataset_file
=
datasets_dir
+
'
/QuickDraw/origine/sheep.npy
'
data_shape
=
(
28
,
28
,
1
)
```
%% Cell type:markdown id: tags:
#### Cleaning
%% Cell type:code id: tags:
```
python
# You can comment these lines to keep each run...
shutil
.
rmtree
(
f
'
{
run_dir
}
/figs
'
,
ignore_errors
=
True
)
shutil
.
rmtree
(
f
'
{
run_dir
}
/models
'
,
ignore_errors
=
True
)
shutil
.
rmtree
(
f
'
{
run_dir
}
/tb_logs
'
,
ignore_errors
=
True
)
```
%% Cell type:markdown id: tags:
## Step 2 - Get some nice data
%% Cell type:markdown id: tags:
#### Get a Nice DataModule
Our DataModule is defined in
[
./modules/QuickDrawDataModule.py
](
./modules/QuickDrawDataModule.py
)
This is a
[
LightningDataModule
](
https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html
)
%% Cell type:code id: tags:
```
python
dm
=
QuickDrawDataModule
(
dataset_file
,
scale
,
batch_size
,
num_workers
=
8
)
dm
.
setup
()
```
%% Cell type:markdown id: tags:
#### Have a look
%% Cell type:code id: tags:
```
python
dl
=
dm
.
train_dataloader
()
batch_data
=
next
(
iter
(
dl
))
fidle
.
scrawler
.
images
(
batch_data
.
reshape
(
-
1
,
28
,
28
),
indices
=
range
(
batch_size
),
columns
=
12
,
x_size
=
1
,
y_size
=
1
,
y_padding
=
0
,
spines_alpha
=
0
,
save_as
=
'
01-Sheeps
'
)
```
%% Cell type:markdown id: tags:
## Step 3 - Get a nice GAN model
Our Generators are defined in
[
./modules/Generators.py
](
./modules/Generators.py
)
Our Discriminators are defined in
[
./modules/Discriminators.py
](
./modules/Discriminators.py
)
Our GAN is defined in
[
./modules/GAN.py
](
./modules/GAN.py
)
#### Class loader
%% Cell type:code id: tags:
```
python
def
get_class
(
class_name
):
module
=
sys
.
modules
[
'
__main__
'
]
class_
=
getattr
(
module
,
class_name
)
return
class_
def
get_instance
(
class_name
,
**
args
):
module
=
sys
.
modules
[
'
__main__
'
]
class_
=
getattr
(
module
,
class_name
)
instance_
=
class_
(
**
args
)
return
instance_
```
%% Cell type:markdown id: tags:
#### Basic test - Just to be sure it (could) works... ;-)
%% Cell type:code id: tags:
```
python
# ---- A little piece of black magic to instantiate a class from its name
#
def
get_classByName
(
class_name
,
**
args
):
module
=
sys
.
modules
[
'
__main__
'
]
class_
=
getattr
(
module
,
class_name
)
instance_
=
class_
(
**
args
)
return
instance_
# ----Get it, and play with them
#
print
(
'
\n
Instantiation :
\n
'
)
Generator_
=
get_class
(
generator_class
)
Discriminator_
=
get_class
(
discriminator_class
)
generator
=
Generator_
(
latent_dim
=
latent_dim
,
data_shape
=
data_shape
)
discriminator
=
Discriminator_
(
latent_dim
=
latent_dim
,
data_shape
=
data_shape
)
print
(
'
\n
Few tests :
\n
'
)
z
=
torch
.
randn
(
batch_size
,
latent_dim
)
print
(
'
z size :
'
,
z
.
size
())
fake_img
=
generator
.
forward
(
z
)
print
(
'
fake_img :
'
,
fake_img
.
size
())
p
=
discriminator
.
forward
(
fake_img
)
print
(
'
pred fake :
'
,
p
.
size
())
print
(
'
batch_data :
'
,
batch_data
.
size
())
p
=
discriminator
.
forward
(
batch_data
)
print
(
'
pred real :
'
,
p
.
size
())
nimg
=
fake_img
.
detach
().
numpy
()
fidle
.
scrawler
.
images
(
nimg
.
reshape
(
-
1
,
28
,
28
),
indices
=
range
(
batch_size
),
columns
=
12
,
x_size
=
1
,
y_size
=
1
,
y_padding
=
0
,
spines_alpha
=
0
,
save_as
=
'
01-Sheeps
'
)
```
%% Cell type:code id: tags:
```
python
print
(
fake_img
.
size
())
print
(
batch_data
.
size
())
e
=
torch
.
distributions
.
uniform
.
Uniform
(
0
,
1
).
sample
([
32
,
1
])
e
=
e
[:
None
,
None
,
None
]
i
=
fake_img
*
e
+
(
1
-
e
)
*
batch_data
nimg
=
i
.
detach
().
numpy
()
fidle
.
scrawler
.
images
(
nimg
.
reshape
(
-
1
,
28
,
28
),
indices
=
range
(
batch_size
),
columns
=
12
,
x_size
=
1
,
y_size
=
1
,
y_padding
=
0
,
spines_alpha
=
0
,
save_as
=
'
01-Sheeps
'
)
```
%% Cell type:markdown id: tags:
#### GAN model
To simplify our code, the GAN class is defined separately in the module
[
./modules/GAN.py
](
./modules/GAN.py
)
Passing the classe names for generator/discriminator by parameter allows to stay modular and to use the PL checkpoints.
%% Cell type:code id: tags:
```
python
GAN_
=
get_class
(
gan_class
)
gan
=
GAN_
(
data_shape
=
data_shape
,
lr
=
lr
,
b1
=
b1
,
b2
=
b2
,
batch_size
=
batch_size
,
latent_dim
=
latent_dim
,
generator_class
=
generator_class
,
discriminator_class
=
discriminator_class
)
```
%% Cell type:markdown id: tags:
## Step 5 - Train it !
#### Instantiate Callbacks, Logger & co.
More about :
-
[
Checkpoints
](
https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html
)
-
[
modelCheckpoint
](
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint
)
%% Cell type:code id: tags:
```
python
# ---- for tensorboard logs
#
logger
=
TensorBoardLogger
(
save_dir
=
f
'
{
run_dir
}
'
,
name
=
'
tb_logs
'
)
log_dir
=
os
.
path
.
abspath
(
f
'
{
run_dir
}
/tb_logs
'
)
print
(
'
To access the logs with tensorboard, use this command line :
'
)
print
(
f
'
tensorboard --logdir
{
log_dir
}
'
)
# ---- To save checkpoints
#
callback_checkpoints
=
ModelCheckpoint
(
dirpath
=
f
'
{
run_dir
}
/models
'
,
filename
=
'
bestModel
'
,
save_top_k
=
1
,
save_last
=
True
,
every_n_epochs
=
1
,
monitor
=
"
g_loss
"
)
# ---- To have a nive progress bar
#
callback_progressBar
=
SmartProgressBar
(
verbosity
=
2
)
# Usable evertywhere
# progress_bar = TQDMProgressBar(refresh_rate=1) # Usable in real jupyter lab (bug in vscode)
```
%% Cell type:markdown id: tags:
#### Train it
%% Cell type:code id: tags:
```
python
trainer
=
Trainer
(
accelerator
=
"
auto
"
,
max_epochs
=
epochs
,
callbacks
=
[
callback_progressBar
,
callback_checkpoints
],
log_every_n_steps
=
batch_size
,
logger
=
logger
)
trainer
.
fit
(
gan
,
dm
)
```
%% Cell type:markdown id: tags:
## Step 6 - Reload our best model
Note :
%% Cell type:code id: tags:
```
python
gan
=
WGANGP
.
load_from_checkpoint
(
'
./run/SHEEP3/models/bestModel.ckpt
'
)
```
%% Cell type:code id: tags:
```
python
nb_images
=
96
z
=
torch
.
randn
(
nb_images
,
latent_dim
)
print
(
'
z size :
'
,
z
.
size
())
fake_img
=
gan
.
generator
.
forward
(
z
)
print
(
'
fake_img :
'
,
fake_img
.
size
())
nimg
=
fake_img
.
detach
().
numpy
()
fidle
.
scrawler
.
images
(
nimg
.
reshape
(
-
1
,
28
,
28
),
indices
=
range
(
nb_images
),
columns
=
12
,
x_size
=
1
,
y_size
=
1
,
y_padding
=
0
,
spines_alpha
=
0
,
save_as
=
'
01-Sheeps
'
)
```
%% Cell type:code id: tags:
```
python
fidle
.
end
()
```
%% Cell type:markdown id: tags:
---
<img
width=
"80px"
src=
"../fidle/img/00-Fidle-logo-01.svg"
></img>
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment