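"""Variational AutoEncoder (VAE) model, built with TensorFlow / Keras.

Assembles a convolutional encoder, a sampling layer (reparameterization
trick) and a convolutional decoder, with compile / train / plot helpers.
"""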
import numpy as np

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda
from tensorflow.keras.layers import Activation, BatchNormalization, LeakyReLU, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint 
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model

import modules.callbacks
import os



class VariationalAutoencoder:

    
    def __init__(self, input_shape=None, encoder_layers=None, decoder_layers=None, z_dim=None, run_tag='default', verbose=0):
        
        self.name           = 'Variational AutoEncoder'
        self.input_shape    = input_shape
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.z_dim          = z_dim
        self.verbose        = verbose
        self.run_directory  = f'./run/{run_tag}'
        
        # ---- Create run directories
        for d in ('','/models','/figs','/logs','/images'):
            os.makedirs(self.run_directory+d, mode=0o750, exist_ok=True)
        
        # ==== Encoder ================================================================
        
        # ---- Input layer
        encoder_input = Input(shape=self.input_shape, name='encoder_input')
        x = encoder_input
        
        # ---- Add the hidden layers described in encoder_layers
        for i, params in enumerate(encoder_layers, start=1):
            params = dict(params)     # work on a copy, so the caller's config is not modified
            t = params.pop('type')
            if t == 'Conv2D':
                layer = Conv2D(**params, name=f"Layer_{i}")
            elif t == 'Dropout':
                layer = Dropout(**params)
            else:
                raise ValueError(f"Unknown encoder layer type : {t}")
            x = layer(x)
            
        # ---- Flatten
        shape_before_flattening = K.int_shape(x)[1:]
        x = Flatten()(x)
        
        # ---- mu and log_var heads
        self.mu      = Dense(self.z_dim, name='mu')(x)
        self.log_var = Dense(self.z_dim, name='log_var')(x)

        self.encoder_mu_log_var = Model(encoder_input, (self.mu, self.log_var))

        # ---- output layer
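        # The "reparameterization trick" : sampling z ~ N(mu, sigma^2) directly
        # would not be differentiable, so we sample epsilon ~ N(0,1) instead
        # and compute z = mu + exp(log_var/2) * epsilon, which keeps the graph
        # differentiable with respect to mu and log_var.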
        def sampling(args):
            mu, log_var = args
            epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
            return mu + K.exp(log_var / 2) * epsilon

        encoder_output = Lambda(sampling, name='encoder_output')([self.mu, self.log_var])

        self.encoder = Model(encoder_input, encoder_output)

        # ==== Decoder ================================================================

        # ---- Input layer
        decoder_input = Input(shape=(self.z_dim,), name='decoder_input')
        
        # ---- First dense layer
        x = Dense(np.prod(shape_before_flattening))(decoder_input)
        x = Reshape(shape_before_flattening)(x)

        # ---- Add the hidden layers described in decoder_layers
        for i, params in enumerate(decoder_layers, start=1):
            params = dict(params)     # work on a copy, so the caller's config is not modified
            t = params.pop('type')
            if t == 'Conv2DT':
                layer = Conv2DTranspose(**params, name=f"Layer_{i}")
            elif t == 'Dropout':
                layer = Dropout(**params)
            else:
                raise ValueError(f"Unknown decoder layer type : {t}")
            x = layer(x)

        decoder_output = x

        self.decoder = Model(decoder_input, decoder_output)
        
        # ==== Encoder-Decoder ========================================================
        
        model_input = encoder_input
        model_output = self.decoder(encoder_output)

        self.model = Model(model_input, model_output)

        # ==== Verbosity ==============================================================

        print('Model initialized.')
        print('Outputs will be in :', self.run_directory)
        
        if verbose>0 :
            print('\n','-'*10,'Encoder','-'*50,'\n')
            self.encoder.summary()
            print('\n','-'*10,'Decoder','-'*50,'\n')
            self.decoder.summary()
            self.plot_model()
        
        
        
    def compile(self, learning_rate, r_loss_factor):
        self.learning_rate = learning_rate
        self.r_loss_factor = r_loss_factor

        def vae_r_loss(y_true, y_pred):
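            # Reconstruction loss : per-pixel MSE averaged over the image axes,
            # scaled by r_loss_factor to balance it against the KL term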
            r_loss = K.mean(K.square(y_true - y_pred), axis = [1,2,3])
            return r_loss_factor * r_loss

        def vae_kl_loss(y_true, y_pred):
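            # Closed-form KL divergence between N(mu, sigma^2) and N(0,1),
            # summed over the latent dimensions :
            #   KL = -1/2 * sum( 1 + log_var - mu^2 - exp(log_var) )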
            kl_loss =  -0.5 * K.sum(1 + self.log_var - K.square(self.mu) - K.exp(self.log_var), axis = 1)
            return kl_loss

        def vae_loss(y_true, y_pred):
            r_loss = vae_r_loss(y_true, y_pred)
            kl_loss = vae_kl_loss(y_true, y_pred)
            return  r_loss + kl_loss

        optimizer = Adam(learning_rate=learning_rate)
        # experimental_run_tf_function=False : the custom losses close over
        # the symbolic tensors self.mu / self.log_var, which does not work
        # inside a tf.function-compiled train step (TF 2.0/2.1 workaround)
        self.model.compile(optimizer=optimizer, 
                           loss = vae_loss,
                           metrics = [vae_r_loss, vae_kl_loss], 
                           experimental_run_tf_function=False)
    
    
    def train(self, 
              x_train,x_test,
              batch_size=32, 
              epochs=200, 
              image_periodicity=1,
              chkpt_periodicity=2,
              initial_epoch=0,
              dataset_size=1,
              lr_decay=1):

        # ---- Dataset size
        n_train = int(x_train.shape[0] * dataset_size)
        n_test  = int(x_test.shape[0]  * dataset_size)

        # ---- Needed by callbacks
        self.n_train    = n_train
        self.n_test     = n_test
        self.batch_size = batch_size
        
        # ---- Callbacks
        images_callback = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)
        
        # ---- (optional) LR decay schedule, currently disabled :
        # lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
        
        filename1 = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
        # save_freq counts batches (TF >= 2.2) : convert the periodicity,
        # given in epochs, into a number of batches
        batches_per_epoch = max(1, n_train // batch_size)
        checkpoint1 = ModelCheckpoint(filename1, save_freq=batches_per_epoch*chkpt_periodicity, verbose=0)

        filename2 = self.run_directory+"/models/best_model.h5"
        checkpoint2 = ModelCheckpoint(filename2, save_best_only=True, mode='min', monitor='val_loss', verbose=0)

        callbacks_list = [checkpoint1, checkpoint2, images_callback]

        self.model.fit(x_train[:n_train], x_train[:n_train],
                       batch_size = batch_size,
                       shuffle = True,
                       epochs = epochs,
                       initial_epoch = initial_epoch,
                       callbacks = callbacks_list,
                       validation_data = (x_test[:n_test], x_test[:n_test]))
        
        
    def plot_model(self):
        d=self.run_directory+'/figs'
        plot_model(self.model,   to_file=f'{d}/model.png',   show_shapes = True, show_layer_names = True, expand_nested=True)
        plot_model(self.encoder, to_file=f'{d}/encoder.png', show_shapes = True, show_layer_names = True)
        plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes = True, show_layer_names = True)
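

# ---- A minimal usage sketch (illustrative only). The layer parameters below
# ---- are assumptions for MNIST-like 28x28x1 images, not values taken from
# ---- this project :
#
#   vae = VariationalAutoencoder(
#       input_shape    = (28, 28, 1),
#       encoder_layers = [ {'type':'Conv2D',  'filters':32, 'kernel_size':3, 'strides':2, 'padding':'same', 'activation':'relu'},
#                          {'type':'Conv2D',  'filters':64, 'kernel_size':3, 'strides':2, 'padding':'same', 'activation':'relu'} ],
#       decoder_layers = [ {'type':'Conv2DT', 'filters':32, 'kernel_size':3, 'strides':2, 'padding':'same', 'activation':'relu'},
#                          {'type':'Conv2DT', 'filters':1,  'kernel_size':3, 'strides':2, 'padding':'same', 'activation':'sigmoid'} ],
#       z_dim          = 2,
#       run_tag        = 'demo')
#   vae.compile(learning_rate=0.001, r_loss_factor=1000)
#   vae.train(x_train, x_test, batch_size=100, epochs=10)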