Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • daconcea/fidle
  • bossardl/fidle
  • Julie.Remenant/fidle
  • abijolao/fidle
  • monsimau/fidle
  • karkars/fidle
  • guilgautier/fidle
  • cailletr/fidle
  • talks/fidle
9 results
Show changes
This diff is collapsed.
This diff is collapsed.
from tensorflow.keras.callbacks import Callback, LearningRateScheduler
import numpy as np
import matplotlib.pyplot as plt
import os
class ImagesCallback(Callback):
def __init__(self, initial_epoch=0, image_periodicity=1, vae=None):
self.epoch = initial_epoch
self.image_periodicity = image_periodicity
self.vae = vae
self.images_dir = vae.run_directory+'/images'
batch_per_epochs = int(vae.n_train / vae.batch_size)
self.batch_periodicity = batch_per_epochs*image_periodicity
def on_train_batch_end(self, batch, logs={}):
if batch % self.batch_periodicity == 0:
# ---- Get a random latent point
z_new = np.random.normal(size = (1,self.vae.z_dim))
# ---- Predict an image
image = self.vae.decoder.predict(np.array(z_new))[0]
# ---- Squeeze it if monochrome : (lx,ly,1) -> (lx,ly)
image = image.squeeze()
# ---- Save it
filename=f'{self.images_dir}/img_{self.epoch:05d}_{batch:06d}.jpg'
if len(image.shape) == 2:
plt.imsave(filename, image, cmap='gray_r')
else:
plt.imsave(filename, image)
def on_epoch_begin(self, epoch, logs={}):
self.epoch += 1
def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
'''
Wrapper function to create a LearningRateScheduler with step decay schedule.
'''
def schedule(epoch):
new_lr = initial_lr * (decay_factor ** np.floor(epoch/step_size))
return new_lr
return LearningRateScheduler(schedule)
\ No newline at end of file
import numpy as np
import tensorflow as tf
import tensorflow.keras.datasets.mnist as mnist
def load_MNIST():
# ---- Get data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# ---- Normalization
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype( 'float32') / 255.
# ---- Reshape : (28,28) -> (28,28,1)
x_train = np.expand_dims(x_train, axis=3)
x_test = np.expand_dims(x_test, axis=3)
print('Dataset loaded.')
print('Resized and normalized.')
print(f'x_train shape : {x_train.shape}\nx_test_shape : {x_test.shape}')
return (x_train,y_train),(x_test,y_test)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
VERSION='0.1a'
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.