Skip to content
Snippets Groups Projects

Add save/load functionality to VAE model

Former-commit-id: bdfef613
parent 665b3434
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
%% Cell type:markdown id: tags:
Variational AutoEncoder (VAE) with MNIST
========================================
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
## Episode 1 - Train a model
- Defining a VAE model
- Build the model
- Train it
- Follow the learning process with Tensorboard
%% Cell type:markdown id: tags:
## Step 1 - Init python stuff
%% Cell type:code id: tags:
``` python
import numpy as np
import tensorflow as tf
import tensorflow.keras.datasets.mnist as mnist
import sys, importlib
import modules.vae
importlib.reload(modules.vae)
print('FIDLE 2020 - Variational AutoEncoder (VAE)')
print('TensorFlow version :',tf.__version__)
```
%% Output
FIDLE 2020 - Variational AutoEncoder (VAE)
TensorFlow version : 2.0.0
%% Cell type:markdown id: tags:
## Step 2 - Get data
%% Cell type:code id: tags:
``` python
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = np.expand_dims(x_train, axis=3)
x_test = x_test.astype('float32') / 255.
x_test = np.expand_dims(x_test, axis=3)
print('Dataset loaded.')
print(f'x_train shape : {x_train.shape}\nx_test_shape : {x_test.shape}')
```
%% Output
Dataset loaded.
x_train shape : (60000, 28, 28, 1)
x_test_shape : (10000, 28, 28, 1)
%% Cell type:markdown id: tags:
## Step 3 - Get VAE model
%% Cell type:code id: tags:
``` python
tag = '004'
input_shape = (28,28,1)
z_dim = 2
verbose = 0
encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
{'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}
]
decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
{'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2DT', 'filters':1, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}
]
vae = modules.vae.VariationalAutoencoder(input_shape = input_shape,
encoder_layers = encoder,
decoder_layers = decoder,
z_dim = z_dim,
verbose = verbose,
run_tag = tag)
```
%% Output
Model initialized.
Outputs will be in : ./run/004
%% Cell type:code id: tags:
``` python
vae.save()
```
%% Output
Config saved in : ./run/004/models/vae_config.json
Model saved in : ./run/004/models/model.h5
%% Cell type:code id: tags:
``` python
vae=modules.vae.VariationalAutoencoder.load('004')
```
%% Output
dict_keys(['input_shape', 'encoder_layers', 'decoder_layers', 'z_dim', 'run_tag', 'verbose'])
Model initialized.
Outputs will be in : ./run/004
Weights loaded from : ./run/004/models/model.h5
%% Cell type:markdown id: tags:
## Step 4 - Compile it
%% Cell type:code id: tags:
``` python
learning_rate = 0.0005
r_loss_factor = 1000
vae.compile(learning_rate, r_loss_factor)
```
%% Output
Compiled.
Optimizer is Adam with learning_rate=0.0005
%% Cell type:markdown id: tags:
## Step 5 - Train
%% Cell type:code id: tags:
``` python
batch_size = 100
epochs = 100
image_periodicity = 1 # for each epoch
chkpt_periodicity = 2 # for each epoch
initial_epoch = 0
dataset_size = 1
```
%% Cell type:code id: tags:
``` python
vae.train(x_train,
x_test,
batch_size = batch_size,
epochs = epochs,
image_periodicity = image_periodicity,
chkpt_periodicity = chkpt_periodicity,
initial_epoch = initial_epoch,
dataset_size = dataset_size
)
```
%% Cell type:code id: tags:
``` python
vae.
```
%% Cell type:markdown id: tags:
Variational AutoEncoder (VAE) with MNIST
========================================
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
## Episode 2 - Analyse our trained model
- Defining a VAE model
- Build the model
- Train it
- Follow the learning process with Tensorboard
%% Cell type:markdown id: tags:
## Step 1 - Init python stuff
%% Cell type:code id: tags:
``` python
import numpy as np
import tensorflow as tf
import tensorflow.keras.datasets.mnist as mnist
import sys, importlib
import modules.vae
importlib.reload(modules.vae)
print('FIDLE 2020 - Variational AutoEncoder (VAE)')
print('TensorFlow version :',tf.__version__)
```
%% Output
FIDLE 2020 - Variational AutoEncoder (VAE)
TensorFlow version : 2.0.0
%% Cell type:markdown id: tags:
## Step 2 - Get data
%% Cell type:code id: tags:
``` python
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = np.expand_dims(x_train, axis=3)
x_test = x_test.astype('float32') / 255.
x_test = np.expand_dims(x_test, axis=3)
print('Dataset loaded.')
print(f'x_train shape : {x_train.shape}\nx_test_shape : {x_test.shape}')
```
%% Output
Dataset loaded.
x_train shape : (60000, 28, 28, 1)
x_test_shape : (10000, 28, 28, 1)
%% Cell type:markdown id: tags:
## Step 3 - Load best model
%% Cell type:code id: tags:
``` python
vae
```
%% Cell type:code id: tags:
``` python
end_time = time.time()
dt = end_time-start_time
dth = str(datetime.timedelta(seconds=dt))
print(f'\nTrain duration : {dt:.2f} sec. - {dth:}')
```
%% Cell type:markdown id: tags:
## Step 4 - Compile it
%% Cell type:code id: tags:
``` python
learning_rate = 0.0005
r_loss_factor = 1000
vae.compile(learning_rate, r_loss_factor)
```
%% Cell type:markdown id: tags:
## Step 5 - Train
%% Cell type:code id: tags:
``` python
batch_size = 100
epochs = 200
image_periodicity = 1 # for each epoch
chkpt_periodicity = 2 # for each epoch
initial_epoch = 0
dataset_size = 1
```
%% Cell type:code id: tags:
``` python
vae.train(x_train,
x_test,
batch_size = batch_size,
epochs = epochs,
image_periodicity = image_periodicity,
chkpt_periodicity = chkpt_periodicity,
initial_epoch = initial_epoch,
dataset_size = dataset_size,
lr_decay = 1
)
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
...@@ -13,20 +13,21 @@ from tensorflow.keras.utils import plot_model ...@@ -13,20 +13,21 @@ from tensorflow.keras.utils import plot_model
import tensorflow.keras.datasets.imdb as imdb import tensorflow.keras.datasets.imdb as imdb
import modules.callbacks import modules.callbacks
import os import os, json, time, datetime
class VariationalAutoencoder(): class VariationalAutoencoder():
def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='default', verbose=0): def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='000', verbose=0):
self.name = 'Variational AutoEncoder' self.name = 'Variational AutoEncoder'
self.input_shape = input_shape self.input_shape = list(input_shape)
self.encoder_layers = encoder_layers self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers self.decoder_layers = decoder_layers
self.z_dim = z_dim self.z_dim = z_dim
self.run_tag = str(run_tag)
self.verbose = verbose self.verbose = verbose
self.run_directory = f'./run/{run_tag}' self.run_directory = f'./run/{run_tag}'
...@@ -42,13 +43,14 @@ class VariationalAutoencoder(): ...@@ -42,13 +43,14 @@ class VariationalAutoencoder():
# ---- Add next layers # ---- Add next layers
i=1 i=1
for params in encoder_layers: for l_config in encoder_layers:
t=params['type'] l_type = l_config['type']
params.pop('type') l_params = l_config.copy()
if t=='Conv2D': l_params.pop('type')
layer = Conv2D(**params, name=f"Layer_{i}") if l_type=='Conv2D':
if t=='Dropout': layer = Conv2D(**l_params)
layer = Dropout(**params) if l_type=='Dropout':
layer = Dropout(**l_params)
x = layer(x) x = layer(x)
i+=1 i+=1
...@@ -83,13 +85,14 @@ class VariationalAutoencoder(): ...@@ -83,13 +85,14 @@ class VariationalAutoencoder():
# ---- Add next layers # ---- Add next layers
i=1 i=1
for params in decoder_layers: for l_config in decoder_layers:
t=params['type'] l_type = l_config['type']
params.pop('type') l_params = l_config.copy()
if t=='Conv2DT': l_params.pop('type')
layer = Conv2DTranspose(**params, name=f"Layer_{i}") if l_type=='Conv2DT':
if t=='Dropout': layer = Conv2DTranspose(**l_params)
layer = Dropout(**params) if l_type=='Dropout':
layer = Dropout(**l_params)
x = layer(x) x = layer(x)
i+=1 i+=1
...@@ -140,6 +143,8 @@ class VariationalAutoencoder(): ...@@ -140,6 +143,8 @@ class VariationalAutoencoder():
loss = vae_loss, loss = vae_loss,
metrics = [vae_r_loss, vae_kl_loss], metrics = [vae_r_loss, vae_kl_loss],
experimental_run_tf_function=False) experimental_run_tf_function=False)
print('Compiled.')
print(f'Optimizer is Adam with learning_rate={learning_rate:}')
def train(self, def train(self,
...@@ -165,7 +170,7 @@ class VariationalAutoencoder(): ...@@ -165,7 +170,7 @@ class VariationalAutoencoder():
callbacks_images = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self) callbacks_images = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)
# ---- Callback : Learning rate scheduler # ---- Callback : Learning rate scheduler
lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) #lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
# ---- Callback : Checkpoint # ---- Callback : Checkpoint
filename = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5" filename = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
...@@ -179,17 +184,23 @@ class VariationalAutoencoder(): ...@@ -179,17 +184,23 @@ class VariationalAutoencoder():
dirname = self.run_directory+"/logs" dirname = self.run_directory+"/logs"
callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1) callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1)
callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard, lr_sched] callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard]
self.model.fit(x_train[:n_train], x_train[:n_train], # ---- Let's go...
batch_size = batch_size, start_time = time.time()
shuffle = True, self.history = self.model.fit(x_train[:n_train], x_train[:n_train],
epochs = epochs, batch_size = batch_size,
initial_epoch = initial_epoch, shuffle = True,
callbacks = callbacks_list, epochs = epochs,
validation_data = (x_test[:n_test], x_test[:n_test]) initial_epoch = initial_epoch,
) callbacks = callbacks_list,
validation_data = (x_test[:n_test], x_test[:n_test])
)
end_time = time.time()
dt = end_time-start_time
dth = str(datetime.timedelta(seconds=int(dt)))
self.duration = dt
print(f'\nTrain duration : {dt:.2f} sec. - {dth:}')
def plot_model(self): def plot_model(self):
d=self.run_directory+'/figs' d=self.run_directory+'/figs'
...@@ -198,3 +209,39 @@ class VariationalAutoencoder(): ...@@ -198,3 +209,39 @@ class VariationalAutoencoder():
plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes = True, show_layer_names = True) plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes = True, show_layer_names = True)
def save(self,config='vae_config.json', model='model.h5'):
# ---- Save config in json
if config!=None:
to_save = ['input_shape', 'encoder_layers', 'decoder_layers', 'z_dim', 'run_tag', 'verbose']
data = { i:self.__dict__[i] for i in to_save }
filename = self.run_directory+'/models/'+config
with open(filename, 'w') as outfile:
json.dump(data, outfile)
print(f'Config saved in : {filename}')
# ---- Save model
if model!=None:
filename = self.run_directory+'/models/'+model
self.model.save(filename)
print(f'Model saved in : {filename}')
def load_weights(self,model='model.h5'):
filename = self.run_directory+'/models/'+model
self.model.load_weights(filename)
print(f'Weights loaded from : {filename}')
@classmethod
def load(cls, run_tag='000', config='vae_config.json', model='model.h5'):
# ---- Instantiate a new vae
filename = f'./run/{run_tag}/models/{config}'
with open(filename, 'r') as infile:
params=json.load(infile)
print(params.keys())
# vae=cls( params['input_shape'], params['encoder_layers'], params['decoder_layers'], params['z_dim'], '004', 0)
vae=cls( **params)
# ---- model==None, just return it
if model==None: return vae
# ---- model!=None, get weight
vae.load_weights(model)
return vae
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment