Skip to content
Snippets Groups Projects

Change VAE

parent 6b97bfb6
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
Variational AutoEncoder
=======================
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
## Variational AutoEncoder (VAE), with MNIST Dataset
%% Cell type:markdown id: tags:
## Step 1 - Init python stuff
%% Cell type:code id: tags:
``` python
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.datasets.mnist as mnist
import modules.vae
# from modules.vae import VariationalAutoencoder
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os,sys,h5py,json
from importlib import reload
sys.path.append('..')
import fidle.pwk as ooo
reload(ooo)
ooo.init()
```
%% Cell type:markdown id: tags:
## Step 2 - Get data
%% Cell type:code id: tags:
``` python
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = np.expand_dims(x_train, axis=3)
x_test = x_test.astype('float32') / 255.
x_test = np.expand_dims(x_test, axis=3)
print(x_train.shape)
print(x_test.shape)
```
%% Cell type:markdown id: tags:
## Step 3 - Get VAE model
%% Cell type:code id: tags:
``` python
reload(modules.vae)
reload(modules.callbacks)
# reload(modules.vae)
# reload(modules.callbacks)
tag = '000'
input_shape = (28,28,1)
z_dim = 2
verbose = 0
encoder= [ {'type':'Conv2D', 'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
{'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2D', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}
]
decoder= [ {'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
{'type':'Conv2DT', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2DT', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
{'type':'Conv2DT', 'filters':1, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}
]
vae = modules.vae.VariationalAutoencoder(input_shape = input_shape,
encoder_layers = encoder,
decoder_layers = decoder,
z_dim = z_dim,
verbose = verbose,
run_tag = tag)
```
%% Cell type:markdown id: tags:
## Step 4 - Compile it
%% Cell type:code id: tags:
``` python
learning_rate = 0.0005
r_loss_factor = 1000
vae.compile(learning_rate, r_loss_factor)
```
%% Cell type:markdown id: tags:
## Step 5 - Train
%% Cell type:code id: tags:
``` python
batch_size = 100
epochs = 200
batch_periodicity = 1000
image_periodicity = 1 # in epoch
chkpt_periodicity = 2 # in epoch
initial_epoch = 0
dataset_size = 0.1
dataset_size = 1
```
%% Cell type:code id: tags:
``` python
vae.train(x_train,
x_test,
batch_size = batch_size,
epochs = epochs,
batch_periodicity = batch_periodicity,
image_periodicity = image_periodicity,
chkpt_periodicity = chkpt_periodicity,
initial_epoch = initial_epoch,
dataset_size = dataset_size,
lr_decay = 1
)
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
......
......@@ -5,12 +5,15 @@ import os
class ImagesCallback(Callback):
def __init__(self, initial_epoch=0, batch_periodicity=1000, vae=None):
def __init__(self, initial_epoch=0, image_periodicity=1, vae=None):
self.epoch = initial_epoch
self.batch_periodicity = batch_periodicity
self.image_periodicity = image_periodicity
self.vae = vae
self.images_dir = vae.run_directory+'/images'
batch_per_epochs = int(vae.n_train / vae.batch_size)
self.batch_periodicity = batch_per_epochs*image_periodicity
def on_train_batch_end(self, batch, logs={}):
if batch % self.batch_periodicity == 0:
......
......@@ -144,8 +144,10 @@ class VariationalAutoencoder():
def train(self,
x_train,x_test,
batch_size=32, epochs=200,
batch_periodicity=100,
batch_size=32,
epochs=200,
image_periodicity=1,
chkpt_periodicity=2,
initial_epoch=0,
dataset_size=1,
lr_decay=1):
......@@ -154,14 +156,18 @@ class VariationalAutoencoder():
n_train = int(x_train.shape[0] * dataset_size)
n_test = int(x_test.shape[0] * dataset_size)
# ---- Need by callbacks
self.n_train = n_train
self.n_test = n_test
self.batch_size = batch_size
# ---- Callbacks
images_callback = modules.callbacks.ImagesCallback(initial_epoch, batch_periodicity, self)
images_callback = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)
# lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
filename1 = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
batch_per_epoch = int(len(x_train)/batch_size)
checkpoint1 = ModelCheckpoint(filename1, save_freq=batch_per_epoch*5,verbose=0)
checkpoint1 = ModelCheckpoint(filename1, save_freq=n_train*chkpt_periodicity ,verbose=0)
filename2 = self.run_directory+"/models/best_model.h5"
checkpoint2 = ModelCheckpoint(filename2, save_best_only=True, mode='min',monitor='val_loss',verbose=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment