Skip to content
Snippets Groups Projects
Commit dc0e2b01 authored by Jean-Luc Parouty's avatar Jean-Luc Parouty
Browse files

Update AE with MNIST dataset

parent 606a540c
No related branches found
No related tags found
1 merge request!5Update style in README
This diff is collapsed.
......@@ -15,22 +15,33 @@ class AE(keras.Model):
def train_step(self, data):
# See :https://keras.io/guides/customizing_what_happens_in_fit/
x, y = data
if isinstance(data, tuple):
data = data[0]
with tf.GradientTape() as tape:
z = self.encoder(data)
reconstruction = self.decoder(z)
reconstruction_loss = tf.reduce_mean( keras.losses.binary_crossentropy(data, reconstruction) )
reconstruction_loss *= 28*28
grads = tape.gradient(reconstruction_loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
return {
"loss": reconstruction_loss
}
z = self.encoder(x)
y_pred = self.decoder(z)
# Compute the loss value
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# ---- Compute gradients
#
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# ---- Update weights
#
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# ---- Update metrics (includes the metric that tracks the loss)
#
self.compiled_metrics.update_state(y, y_pred)
# ---- Return a dict mapping metric names to current value
#
return {m.name: m.result() for m in self.metrics}
# return {"loss":loss}
def reload(self,filename):
......
......@@ -97,8 +97,8 @@
},
"AE1": {
"path": "/home/pjluc/dev/fidle/VAE",
"start": "Friday 25 December 2020, 19:48:18",
"end": "Friday 25 December 2020, 19:50:24",
"duration": "00:02:06 313ms"
"start": "Saturday 26 December 2020, 12:10:52",
"end": "Saturday 26 December 2020, 12:11:51",
"duration": "00:00:59 384ms"
}
}
\ No newline at end of file
......@@ -589,7 +589,8 @@ def save_fig(filename='auto', png=True, svg=False):
svg : Boolean. Save as svg if True (False)
"""
global _save_figs, _figs_dir, _figs_name, _figs_id
if not _save_figs : return
if filename is None : return
if not _save_figs : return
mkdir(_figs_dir)
if filename=='auto':
path=f'{_figs_dir}/{notebook_id}-{_figs_name}{_figs_id:02d}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment