Skip to content
Snippets Groups Projects
Commit 1f0df5ff authored by Jean-Luc Parouty's avatar Jean-Luc Parouty
Browse files

Update DCGAN

parent 30323274
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
...@@ -71,6 +71,9 @@ class DCGAN(keras.Model): ...@@ -71,6 +71,9 @@ class DCGAN(keras.Model):
loss_function : Loss function loss_function : Loss function
''' '''
super(DCGAN, self).compile() super(DCGAN, self).compile()
self.discriminator.compile(optimizer=discriminator_optimizer, loss=loss_function)
self.generator.compile(optimizer=generator_optimizer, loss=loss_function)
self.d_optimizer = discriminator_optimizer self.d_optimizer = discriminator_optimizer
self.g_optimizer = generator_optimizer self.g_optimizer = generator_optimizer
self.loss_fn = loss_function self.loss_fn = loss_function
...@@ -121,11 +124,11 @@ class DCGAN(keras.Model): ...@@ -121,11 +124,11 @@ class DCGAN(keras.Model):
combined_images = tf.concat( [generated_images, real_images], axis=0) combined_images = tf.concat( [generated_images, real_images], axis=0)
# Creation of labels corresponding to real or fake images # Creation of labels corresponding to real or fake images
# 1 is generated, 0 is real # 0 is generated, 1 is real
labels = tf.concat( [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0) labels = tf.concat( [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0)
# Add random noise to the labels - important trick ! # Add random noise to the labels - important trick !
labels += 0.05 * tf.random.uniform(tf.shape(labels)) # labels += 0.05 * tf.random.uniform(tf.shape(labels))
# ---- Train the discriminator ----------------------------- # ---- Train the discriminator -----------------------------
# ---------------------------------------------------------- # ----------------------------------------------------------
...@@ -155,7 +158,7 @@ class DCGAN(keras.Model): ...@@ -155,7 +158,7 @@ class DCGAN(keras.Model):
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim)) random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Assemble labels that say all images are real, yes it's a lie ;-) # Assemble labels that say all images are real, yes it's a lie ;-)
misleading_labels = tf.zeros((batch_size, 1)) misleading_labels = tf.ones((batch_size, 1))
# ---- Train the generator --------------------------------- # ---- Train the generator ---------------------------------
# ---------------------------------------------------------- # ----------------------------------------------------------
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment