From 9e4e9e88fa44d5d96ba5ab3e0f537effbb1d437c Mon Sep 17 00:00:00 2001
From: "Jean-Luc Parouty Jean-Luc.Parouty@simap.grenoble-inp.fr"
 <paroutyj@f-dahu.u-ga.fr>
Date: Mon, 3 Feb 2020 20:45:11 +0100
Subject: [PATCH] Change VAE

Former-commit-id: 7d55dedb383b84d4bfa6e2817f955b8daaf099b1
---
 VAE/01-VAE with MNIST.ipynb | 14 +++++++-------
 VAE/modules/callbacks.py    |  9 ++++++---
 VAE/modules/vae.py          | 16 +++++++++++-----
 3 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/VAE/01-VAE with MNIST.ipynb b/VAE/01-VAE with MNIST.ipynb
index 5c0c2b9..5c74509 100644
--- a/VAE/01-VAE with MNIST.ipynb	
+++ b/VAE/01-VAE with MNIST.ipynb	
@@ -33,7 +33,6 @@
     "import tensorflow.keras.datasets.mnist as mnist\n",
     "\n",
     "import modules.vae\n",
-    "# from modules.vae import VariationalAutoencoder\n",
     "\n",
     "import matplotlib.pyplot as plt\n",
     "import matplotlib\n",
@@ -46,7 +45,6 @@
     "sys.path.append('..')\n",
     "import fidle.pwk as ooo\n",
     "\n",
-    "reload(ooo)\n",
     "ooo.init()"
    ]
   },
@@ -86,8 +84,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "reload(modules.vae)\n",
-    "reload(modules.callbacks)\n",
+    "# reload(modules.vae)\n",
+    "# reload(modules.callbacks)\n",
     "\n",
     "tag = '000'\n",
     "\n",
@@ -149,9 +147,10 @@
    "source": [
     "batch_size        = 100\n",
     "epochs            = 200\n",
-    "batch_periodicity = 1000\n",
+    "image_periodicity = 1      # in epoch\n",
+    "chkpt_periodicity = 2        # in epoch\n",
     "initial_epoch     = 0\n",
-    "dataset_size      = 0.1"
+    "dataset_size      = 1"
    ]
   },
   {
@@ -164,7 +163,8 @@
     "          x_test,\n",
     "          batch_size        = batch_size, \n",
     "          epochs            = epochs,\n",
-    "          batch_periodicity = batch_periodicity,\n",
+    "          image_periodicity = image_periodicity,\n",
+    "          chkpt_periodicity = chkpt_periodicity,\n",
     "          initial_epoch     = initial_epoch,\n",
     "          dataset_size      = dataset_size,\n",
     "          lr_decay          = 1\n",
diff --git a/VAE/modules/callbacks.py b/VAE/modules/callbacks.py
index a7bbe1a..63adb23 100644
--- a/VAE/modules/callbacks.py
+++ b/VAE/modules/callbacks.py
@@ -5,12 +5,15 @@ import os
 
 class ImagesCallback(Callback):
     
-    def __init__(self, initial_epoch=0, batch_periodicity=1000, vae=None):
+    def __init__(self, initial_epoch=0, image_periodicity=1, vae=None):
         self.epoch             = initial_epoch
-        self.batch_periodicity = batch_periodicity
+        self.image_periodicity = image_periodicity
         self.vae               = vae
         self.images_dir        = vae.run_directory+'/images'
-
+        batch_per_epochs       = int(vae.n_train / vae.batch_size)
+        self.batch_periodicity = batch_per_epochs*image_periodicity
+        
+        
     def on_train_batch_end(self, batch, logs={}):  
         
         if batch % self.batch_periodicity == 0:
diff --git a/VAE/modules/vae.py b/VAE/modules/vae.py
index 8f09c47..f9efa3b 100644
--- a/VAE/modules/vae.py
+++ b/VAE/modules/vae.py
@@ -144,8 +144,10 @@ class VariationalAutoencoder():
     
     def train(self, 
               x_train,x_test,
-              batch_size=32, epochs=200, 
-              batch_periodicity=100, 
+              batch_size=32, 
+              epochs=200, 
+              image_periodicity=1,
+              chkpt_periodicity=2,
               initial_epoch=0,
               dataset_size=1,
               lr_decay=1):
@@ -154,14 +156,18 @@ class VariationalAutoencoder():
         n_train = int(x_train.shape[0] * dataset_size)
         n_test  = int(x_test.shape[0]  * dataset_size)
 
+        # ---- Need by callbacks
+        self.n_train    = n_train
+        self.n_test     = n_test
+        self.batch_size = batch_size
+        
         # ---- Callbacks
-        images_callback = modules.callbacks.ImagesCallback(initial_epoch, batch_periodicity, self)
+        images_callback = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self)
         
 #         lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
         
         filename1 = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5"
-        batch_per_epoch = int(len(x_train)/batch_size)
-        checkpoint1 = ModelCheckpoint(filename1, save_freq=batch_per_epoch*5,verbose=0)
+        checkpoint1 = ModelCheckpoint(filename1, save_freq=n_train*chkpt_periodicity ,verbose=0)
 
         filename2 = self.run_directory+"/models/best_model.h5"
         checkpoint2 = ModelCheckpoint(filename2, save_best_only=True, mode='min',monitor='val_loss',verbose=0)
-- 
GitLab