"""Keras callbacks used while training a VAE.

Provides a callback that periodically saves an image decoded from a random
latent point, and a helper building a step-decay learning-rate scheduler.
"""

import os

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import Callback, LearningRateScheduler


class ImagesCallback(Callback):
    """Periodically decode a random latent point and save the image to disk.

    Roughly every ``image_periodicity`` epochs' worth of batches (the epoch
    length is estimated as ``int(vae.n_train / vae.batch_size)``), a point is
    drawn from a standard normal of size ``vae.z_dim``, decoded with
    ``vae.decoder.predict`` and written to
    ``<vae.run_directory>/images/img_<epoch:05d>_<batch:06d>.jpg``.

    Args:
        initial_epoch: Starting value of the epoch counter used in file
            names. The counter is incremented at the start of every epoch,
            so the first epoch's images carry ``initial_epoch + 1``.
        image_periodicity: Number of epochs between two saved images (may be
            fractional, e.g. 0.5 for two images per epoch).
        vae: Object exposing ``run_directory``, ``n_train``, ``batch_size``,
            ``z_dim`` and ``decoder`` (something with a ``predict`` method).

    Raises:
        ValueError: If ``vae`` is None.
    """

    def __init__(self, initial_epoch=0, image_periodicity=1, vae=None):
        super().__init__()
        if vae is None:
            raise ValueError("ImagesCallback requires a 'vae' argument")
        self.epoch = initial_epoch
        self.image_periodicity = image_periodicity
        self.vae = vae
        self.images_dir = os.path.join(vae.run_directory, 'images')
        batch_per_epochs = int(vae.n_train / vae.batch_size)
        # Clamp to at least 1: with fewer samples than one batch (or a zero
        # periodicity) the modulo below would divide by zero.
        self.batch_periodicity = max(1, round(batch_per_epochs * image_periodicity))
        # Keras resets the `batch` index at every epoch, so a global counter
        # is needed for periodicities longer than one epoch to take effect.
        self._seen_batches = 0

    def on_train_batch_end(self, batch, logs=None):
        """Save a decoded random image every ``batch_periodicity`` batches.

        Args:
            batch: Index of the batch within the current epoch (used only in
                the output file name).
            logs: Unused; accepted for the Keras callback interface.
        """
        seen = self._seen_batches
        self._seen_batches += 1
        if seen % self.batch_periodicity != 0:
            return
        # ---- Get a random latent point
        z_new = np.random.normal(size=(1, self.vae.z_dim))
        # ---- Predict an image (first and only element of the batch)
        image = self.vae.decoder.predict(z_new)[0]
        # ---- Squeeze it if monochrome : (lx,ly,1) -> (lx,ly)
        image = image.squeeze()
        # ---- Save it (create the output directory on first use)
        os.makedirs(self.images_dir, exist_ok=True)
        filename = os.path.join(
            self.images_dir, f'img_{self.epoch:05d}_{batch:06d}.jpg'
        )
        if image.ndim == 2:
            plt.imsave(filename, image, cmap='gray_r')
        else:
            plt.imsave(filename, image)

    def on_epoch_begin(self, epoch, logs=None):
        """Advance the epoch counter used in saved file names.

        The Keras-provided ``epoch`` index is ignored; the counter starts
        from the ``initial_epoch`` given to the constructor.
        """
        self.epoch += 1


def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    """Create a LearningRateScheduler with a step-decay schedule.

    The learning rate for the 0-based epoch index ``e`` passed by Keras is
    ``initial_lr * decay_factor ** floor(e / step_size)``. In other words, it
    is multiplied by ``decay_factor`` every ``step_size`` epochs.

    Args:
        initial_lr: Learning rate used for the first ``step_size`` epochs.
        decay_factor: Multiplicative factor applied at each step.
        step_size: Number of epochs between two decays; must be non-zero.

    Returns:
        A ``LearningRateScheduler`` callback wrapping the schedule.

    Raises:
        ValueError: If ``step_size`` is 0. Without this check the failure
            would only surface as a ZeroDivisionError during training.
    """
    if step_size == 0:
        raise ValueError("step_size must be non-zero")

    def schedule(epoch):
        return initial_lr * (decay_factor ** np.floor(epoch / step_size))

    return LearningRateScheduler(schedule)