Skip to content
Snippets Groups Projects
Commit e1c159f9 authored by Jean-Luc Parouty's avatar Jean-Luc Parouty
Browse files

Update VAE MNIST - GPU Validation

parent 4de318c1
No related branches found
No related tags found
1 merge request!5Update style in README
This diff is collapsed.
This diff is collapsed.
...@@ -23,10 +23,12 @@ class Sampling(layers.Layer): ...@@ -23,10 +23,12 @@ class Sampling(layers.Layer):
class VAE(keras.Model): class VAE(keras.Model):
def __init__(self, encoder=None, decoder=None, **kwargs): def __init__(self, encoder=None, decoder=None, r_loss_factor=1., **kwargs):
super(VAE, self).__init__(**kwargs) super(VAE, self).__init__(**kwargs)
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
self.r_loss_factor = r_loss_factor
print('r_loss_factor=',self.r_loss_factor)
def train_step(self, data): def train_step(self, data):
...@@ -42,7 +44,7 @@ class VAE(keras.Model): ...@@ -42,7 +44,7 @@ class VAE(keras.Model):
kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
kl_loss = tf.reduce_mean(kl_loss) kl_loss = tf.reduce_mean(kl_loss)
kl_loss *= -0.5 kl_loss *= -0.5
total_loss = reconstruction_loss + kl_loss total_loss = self.r_loss_factor*reconstruction_loss + (1-self.r_loss_factor)*kl_loss
grads = tape.gradient(total_loss, self.trainable_weights) grads = tape.gradient(total_loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
......
...@@ -79,9 +79,9 @@ ...@@ -79,9 +79,9 @@
}, },
"VAE1": { "VAE1": {
"path": "/gpfsdswork/projects/rech/mlh/uja62cb/fidle/VAE", "path": "/gpfsdswork/projects/rech/mlh/uja62cb/fidle/VAE",
"start": "Monday 21 December 2020, 22:14:18", "start": "Wednesday 23 December 2020, 22:58:02",
"end": "", "end": "Wednesday 23 December 2020, 22:59:24",
"duration": "Unfinished..." "duration": "00:01:22 195ms"
}, },
"MNIST1": { "MNIST1": {
"path": "/home/pjluc/dev/fidle/MNIST", "path": "/home/pjluc/dev/fidle/MNIST",
...@@ -91,7 +91,7 @@ ...@@ -91,7 +91,7 @@
}, },
"VAE2": { "VAE2": {
"path": "/gpfsdswork/projects/rech/mlh/uja62cb/fidle/VAE", "path": "/gpfsdswork/projects/rech/mlh/uja62cb/fidle/VAE",
"start": "Monday 21 December 2020, 22:18:12", "start": "Wednesday 23 December 2020, 23:00:15",
"end": "", "end": "",
"duration": "Unfinished..." "duration": "Unfinished..."
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment