diff --git a/VAE/01-VAE with MNIST.ipynb b/VAE/01-VAE with MNIST.ipynb index 5c0c2b9dc4d1a5701459445e2b8ed00134d1c9f5..5c7450927e8956a11faf1a3355494fa52a20a01b 100644 --- a/VAE/01-VAE with MNIST.ipynb +++ b/VAE/01-VAE with MNIST.ipynb @@ -33,7 +33,6 @@ "import tensorflow.keras.datasets.mnist as mnist\n", "\n", "import modules.vae\n", - "# from modules.vae import VariationalAutoencoder\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", @@ -46,7 +45,6 @@ "sys.path.append('..')\n", "import fidle.pwk as ooo\n", "\n", - "reload(ooo)\n", "ooo.init()" ] }, @@ -86,8 +84,8 @@ "metadata": {}, "outputs": [], "source": [ - "reload(modules.vae)\n", - "reload(modules.callbacks)\n", + "# reload(modules.vae)\n", + "# reload(modules.callbacks)\n", "\n", "tag = '000'\n", "\n", @@ -149,9 +147,10 @@ "source": [ "batch_size = 100\n", "epochs = 200\n", - "batch_periodicity = 1000\n", + "image_periodicity = 1 # in epoch\n", + "chkpt_periodicity = 2 # in epoch\n", "initial_epoch = 0\n", - "dataset_size = 0.1" + "dataset_size = 1" ] }, { @@ -164,7 +163,8 @@ " x_test,\n", " batch_size = batch_size, \n", " epochs = epochs,\n", - " batch_periodicity = batch_periodicity,\n", + " image_periodicity = image_periodicity,\n", + " chkpt_periodicity = chkpt_periodicity,\n", " initial_epoch = initial_epoch,\n", " dataset_size = dataset_size,\n", " lr_decay = 1\n", diff --git a/VAE/modules/callbacks.py b/VAE/modules/callbacks.py index a7bbe1a178d4f98616a9213dc3794207e813fad3..63adb23aede91492630a7fd5e4b7dfde817b8304 100644 --- a/VAE/modules/callbacks.py +++ b/VAE/modules/callbacks.py @@ -5,12 +5,15 @@ import os class ImagesCallback(Callback): - def __init__(self, initial_epoch=0, batch_periodicity=1000, vae=None): + def __init__(self, initial_epoch=0, image_periodicity=1, vae=None): self.epoch = initial_epoch - self.batch_periodicity = batch_periodicity + self.image_periodicity = image_periodicity self.vae = vae self.images_dir = vae.run_directory+'/images' - + batch_per_epochs = int(vae.n_train / vae.batch_size) + self.batch_periodicity = batch_per_epochs*image_periodicity + + def on_train_batch_end(self, batch, logs={}): if batch % self.batch_periodicity == 0: diff --git a/VAE/modules/vae.py b/VAE/modules/vae.py index 8f09c47cc91cd039774f9222bc5c2bec67b87446..f9efa3b28a5264e49631b3f5b29daef4a7e74df7 100644 --- a/VAE/modules/vae.py +++ b/VAE/modules/vae.py @@ -144,8 +144,10 @@ class VariationalAutoencoder(): def train(self, x_train,x_test, - batch_size=32, epochs=200, - batch_periodicity=100, + batch_size=32, + epochs=200, + image_periodicity=1, + chkpt_periodicity=2, initial_epoch=0, dataset_size=1, lr_decay=1): @@ -154,14 +156,18 @@ class VariationalAutoencoder(): n_train = int(x_train.shape[0] * dataset_size) n_test = int(x_test.shape[0] * dataset_size) + # ---- Need by callbacks + self.n_train = n_train + self.n_test = n_test + self.batch_size = batch_size + # ---- Callbacks - images_callback = modules.callbacks.ImagesCallback(initial_epoch, batch_periodicity, self) + images_callback = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self) # lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) filename1 = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5" - batch_per_epoch = int(len(x_train)/batch_size) - checkpoint1 = ModelCheckpoint(filename1, save_freq=batch_per_epoch*5,verbose=0) + checkpoint1 = ModelCheckpoint(filename1, save_freq=n_train*chkpt_periodicity ,verbose=0) filename2 = self.run_directory+"/models/best_model.h5" checkpoint2 = ModelCheckpoint(filename2, save_best_only=True, mode='min',monitor='val_loss',verbose=0)