Change VAE

Former-commit-id: 7d55dedb
parent 2cdfaad9
%% Cell type:markdown id: tags:

Variational AutoEncoder
=======================

---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020

## Variational AutoEncoder (VAE), with MNIST Dataset

%% Cell type:markdown id: tags:

## Step 1 - Init python stuff

%% Cell type:code id: tags:

``` python
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.datasets.mnist as mnist

import modules.vae
# from modules.vae import VariationalAutoencoder

import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os,sys,h5py,json
from importlib import reload

sys.path.append('..')
import fidle.pwk as ooo

reload(ooo)
ooo.init()
```
%% Cell type:markdown id: tags:

## Step 2 - Get data

%% Cell type:code id: tags:

``` python
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_train = np.expand_dims(x_train, axis=3)

x_test  = x_test.astype('float32') / 255.
x_test  = np.expand_dims(x_test, axis=3)

print(x_train.shape)
print(x_test.shape)
```
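
The images are normalized to [0, 1] and given an explicit channel axis, so the tensors have shape (60000, 28, 28, 1) and (10000, 28, 28, 1). As an optional sanity check (not part of the original notebook), a few samples can be displayed directly with matplotlib, which is already imported:

``` python
# Illustrative sanity check: show the first few normalized digits with their labels.
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for i, ax in enumerate(axes):
    ax.imshow(x_train[i].squeeze(), cmap='gray')   # drop the channel axis for imshow
    ax.set_title(str(y_train[i]))
    ax.axis('off')
plt.show()
```
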
%% Cell type:markdown id: tags:

## Step 3 - Get VAE model

%% Cell type:code id: tags:

``` python
# reload(modules.vae)
# reload(modules.callbacks)

tag         = '000'
input_shape = (28,28,1)
z_dim       = 2
verbose     = 0

encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}
         ]

decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DT', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}
         ]

vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape,
                                         encoder_layers = encoder,
                                         decoder_layers = decoder,
                                         z_dim          = z_dim,
                                         verbose        = verbose,
                                         run_tag        = tag)
```
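
The encoder and decoder are described as lists of layer descriptors that modules.vae translates into Keras layers. The module's actual builder is not shown in this commit; the sketch below only illustrates one plausible way such descriptors could be mapped, and build_layers is a hypothetical helper, not part of the FIDLE code:

``` python
# Sketch only: map each {'type': ..., ...} descriptor to a Keras layer instance.
from tensorflow.keras import layers

def build_layers(descriptors):
    """Hypothetical helper: turn layer descriptors into Keras layer objects."""
    mapping = {'Conv2D': layers.Conv2D, 'Conv2DT': layers.Conv2DTranspose}
    built = []
    for d in descriptors:
        params = {k: v for k, v in d.items() if k != 'type'}   # everything except 'type' is a constructor kwarg
        built.append(mapping[d['type']](**params))
    return built

encoder_keras_layers = build_layers(encoder)   # 4 Conv2D layers
decoder_keras_layers = build_layers(decoder)   # 4 Conv2DTranspose layers
```
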
%% Cell type:markdown id: tags:

## Step 4 - Compile it

%% Cell type:code id: tags:

``` python
learning_rate = 0.0005
r_loss_factor = 1000

vae.compile(learning_rate, r_loss_factor)
```
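
r_loss_factor weights the reconstruction error against the KL divergence of the latent distribution; a larger value favors sharper reconstructions at the cost of a less regular latent space. The actual loss lives inside modules/vae.py; the following is only a typical formulation, assuming a mean-squared reconstruction term:

``` python
# Typical weighted VAE loss, shown for illustration only.
import tensorflow.keras.backend as K

def vae_loss(y_true, y_pred, mu, log_var, r_loss_factor=1000):
    # Reconstruction term: per-image mean squared error, scaled by r_loss_factor
    r_loss  = K.mean(K.square(y_true - y_pred), axis=[1, 2, 3]) * r_loss_factor
    # KL term: pulls the latent distribution N(mu, exp(log_var)) towards N(0, 1)
    kl_loss = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=1)
    return K.mean(r_loss + kl_loss)
```
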
%% Cell type:markdown id: tags:

## Step 5 - Train

%% Cell type:code id: tags:

``` python
batch_size        = 100
epochs            = 200
image_periodicity = 1   # in epoch
chkpt_periodicity = 2   # in epoch
initial_epoch     = 0
dataset_size      = 1
```
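
image_periodicity and chkpt_periodicity are now expressed in epochs rather than in batches; the callbacks convert them back into a batch count (see the ImagesCallback change further down). With the full MNIST training set (60,000 images) and batch_size = 100, that gives 600 batches per epoch, so sample images are produced every 600 batches:

``` python
# Illustration of the epoch-to-batch conversion performed by the callbacks.
n_train           = 60000                      # full MNIST training set (dataset_size = 1)
batches_per_epoch = n_train // batch_size      # 60000 // 100 = 600
print(batches_per_epoch * image_periodicity)   # -> 600 batches between image snapshots
```
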
%% Cell type:code id: tags:

``` python
vae.train(x_train,
          x_test,
          batch_size        = batch_size,
          epochs            = epochs,
          image_periodicity = image_periodicity,
          chkpt_periodicity = chkpt_periodicity,
          initial_epoch     = initial_epoch,
          dataset_size      = dataset_size,
          lr_decay          = 1
         )
```
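
lr_decay = 1 keeps the learning rate constant. The commented-out step_decay_schedule call in modules/vae.py (visible in the diff below) hints at an optional step decay; the snippet here is only a hedged sketch of what such a schedule usually looks like in Keras, not the FIDLE implementation:

``` python
# Illustrative step decay: multiply the learning rate by decay_factor every step_size epochs.
from tensorflow.keras.callbacks import LearningRateScheduler

def step_decay_schedule(initial_lr=0.0005, decay_factor=0.9, step_size=1):
    def schedule(epoch):
        return initial_lr * (decay_factor ** (epoch // step_size))
    return LearningRateScheduler(schedule)

# lr_sched = step_decay_schedule(initial_lr=learning_rate, decay_factor=0.9, step_size=1)
# ...would then be passed to fit(..., callbacks=[lr_sched])
```
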
%% Cell type:code id: tags:

``` python
```

%% Cell type:code id: tags:

``` python
```
...
@@ -5,12 +5,15 @@ import os
 class ImagesCallback(Callback):
-    def __init__(self, initial_epoch=0, batch_periodicity=1000, vae=None):
+    def __init__(self, initial_epoch=0, image_periodicity=1, vae=None):
         self.epoch = initial_epoch
-        self.batch_periodicity = batch_periodicity
+        self.image_periodicity = image_periodicity
         self.vae = vae
         self.images_dir = vae.run_directory+'/images'
+        batch_per_epochs = int(vae.n_train / vae.batch_size)
+        self.batch_periodicity = batch_per_epochs*image_periodicity
     def on_train_batch_end(self, batch, logs={}):
         if batch % self.batch_periodicity == 0:
...
@@ -144,8 +144,10 @@ class VariationalAutoencoder():
     def train(self,
               x_train,x_test,
-              batch_size=32, epochs=200,
-              batch_periodicity=100,
+              batch_size=32,
+              epochs=200,
+              image_periodicity=1,
+              chkpt_periodicity=2,
               initial_epoch=0,
               dataset_size=1,
               lr_decay=1):
@@ -154,14 +156,18 @@ class VariationalAutoencoder():
         n_train = int(x_train.shape[0] * dataset_size)
         n_test  = int(x_test.shape[0]  * dataset_size)
+        # ---- Need by callbacks
+        self.n_train    = n_train
+        self.n_test     = n_test
+        self.batch_size = batch_size
         # ---- Callbacks
-        images_callback = modules.callbacks.ImagesCallback(initial_epoch, batch_periodicity, self)
+        images_callback = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)
         # lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
         filename1 = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
-        batch_per_epoch = int(len(x_train)/batch_size)
-        checkpoint1 = ModelCheckpoint(filename1, save_freq=batch_per_epoch*5,verbose=0)
+        checkpoint1 = ModelCheckpoint(filename1, save_freq=n_train*chkpt_periodicity ,verbose=0)
         filename2 = self.run_directory+"/models/best_model.h5"
         checkpoint2 = ModelCheckpoint(filename2, save_best_only=True, mode='min',monitor='val_loss',verbose=0)
...