from tensorflow.keras.callbacks import Callback
import numpy as np
import matplotlib.pyplot as plt
import os

class ImagesCallback(Callback):
    """Keras callback that periodically saves an image decoded from a random latent point.

    Every ``batch_periodicity`` training batches, one latent vector is drawn
    from a standard normal distribution, passed through ``vae.decoder`` and the
    resulting image is written as a JPEG under ``<vae.run_directory>/images``.

    The ``vae`` object must expose ``run_directory``, ``z_dim`` and ``decoder``
    (an object with a ``predict`` method).
    """

    def __init__(self, initial_epoch=0, batch_periodicity=1000, vae=None):
        """Store the configuration of the callback.

        Args:
            initial_epoch: Starting value of the internal epoch counter. The
                counter is incremented at each epoch start, so the first epoch
                of a run is numbered ``initial_epoch + 1`` in filenames.
            batch_periodicity: An image is saved whenever the batch index is a
                multiple of this value.
            vae: Model wrapper providing ``run_directory``, ``z_dim`` and
                ``decoder``. Required despite the ``None`` default.

        Raises:
            ValueError: If ``vae`` is None.
        """
        # Fail early with an explicit message rather than an AttributeError
        # on ``None.run_directory``.
        if vae is None:
            raise ValueError('ImagesCallback requires a vae instance (got None)')
        # Let the Keras base class set up its own internal state.
        super().__init__()
        self.epoch             = initial_epoch
        self.batch_periodicity = batch_periodicity
        self.vae               = vae
        self.images_dir        = vae.run_directory+'/images'

    def on_train_batch_end(self, batch, logs=None):
        """Save one generated image when ``batch`` is a multiple of the periodicity.

        Args:
            batch: Index of the batch that just finished within the current epoch.
            logs: Metric results supplied by Keras; unused here.
        """
        # Note: batch 0 always matches, so one image is saved at the start of
        # every epoch.
        if batch % self.batch_periodicity != 0:
            return

        # ---- Get a random latent point
        z_new = np.random.normal(size=(1, self.vae.z_dim))
        # ---- Predict an image (first and only item of the batch)
        image = self.vae.decoder.predict(z_new)[0]
        # ---- Squeeze it if monochrome : (lx,ly,1) -> (lx,ly)
        image = image.squeeze()
        # ---- Make sure the output folder exists, imsave does not create it
        os.makedirs(self.images_dir, exist_ok=True)
        # ---- Save it
        filename = f'{self.images_dir}/img_{self.epoch:05d}_{batch:06d}.jpg'
        if image.ndim == 2:
            plt.imsave(filename, image, cmap='gray_r')
        else:
            plt.imsave(filename, image)

    def on_epoch_begin(self, epoch, logs=None):
        """Advance the internal epoch counter used in image filenames.

        Args:
            epoch: Epoch index supplied by Keras; unused, the callback keeps
                its own counter seeded from ``initial_epoch``.
            logs: Metric results supplied by Keras; unused here.
        """
        self.epoch += 1