From 68c5e712761d6d184f86b341490a29dc3ca17c4e Mon Sep 17 00:00:00 2001
From: Jean-Luc Parouty <Jean-Luc.Parouty@simap.grenoble-inp.fr>
Date: Tue, 7 Mar 2023 00:40:04 +0100
Subject: [PATCH] Add WGANGP to DCGAN-PL :-)

---
 DCGAN-PyTorch/01-DCGAN-PL.ipynb |  62 +++++----
 DCGAN-PyTorch/modules/GAN.py    |   4 -
 DCGAN-PyTorch/modules/WGANGP.py | 229 ++++++++++++++++++++++++++++++++
 3 files changed, 268 insertions(+), 27 deletions(-)
 create mode 100644 DCGAN-PyTorch/modules/WGANGP.py

diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb
index 294fa50..dd821c7 100644
--- a/DCGAN-PyTorch/01-DCGAN-PL.ipynb
+++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb
@@ -71,6 +71,7 @@
     "from modules.QuickDrawDataModule import QuickDrawDataModule\n",
     "\n",
     "from modules.GAN                 import GAN\n",
+    "from modules.WGANGP              import WGANGP\n",
     "from modules.Generators          import *\n",
     "from modules.Discriminators      import *\n",
     "\n",
@@ -94,11 +95,12 @@
    "outputs": [],
    "source": [
     "latent_dim          = 128\n",
-    "    \n",
+    "\n",
+    "gan_class           = 'GAN'\n",
     "generator_class     = 'Generator_2'\n",
     "discriminator_class = 'Discriminator_2'    \n",
     "    \n",
-    "scale               = 1\n",
+    "scale               = 0.001\n",
     "epochs              = 3\n",
     "lr                  = 0.0001\n",
     "b1                  = 0.5\n",
@@ -125,9 +127,9 @@
    "outputs": [],
    "source": [
     "# You can comment these lines to keep each run...\n",
-    "# shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)\n",
-    "# shutil.rmtree(f'{run_dir}/models', ignore_errors=True)\n",
-    "# shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)"
+    "shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)\n",
+    "shutil.rmtree(f'{run_dir}/models', ignore_errors=True)\n",
+    "shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)"
    ]
   },
   {
@@ -206,7 +208,7 @@
    },
    "outputs": [],
    "source": [
-    "# ---- A little piece of black magic to instantiate Generator and Discriminator from their class names\n",
+    "# ---- A little piece of black magic to instantiate a class from its name\n",
     "#\n",
     "def get_classByName(class_name, **args):\n",
     "    module=sys.modules['__main__']\n",
@@ -240,6 +242,25 @@
     "                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "print(fake_img.size())\n",
+    "print(batch_data.size())\n",
+    "e = torch.distributions.uniform.Uniform(0, 1).sample([32,1])\n",
+    "e = e[:None,None,None]\n",
+    "i = fake_img * e + (1-e)*batch_data\n",
+    "\n",
+    "\n",
+    "nimg = i.detach().numpy()\n",
+    "fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, \n",
+    "                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')\n"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -255,14 +276,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "gan = GAN( data_shape          = data_shape,\n",
-    "           lr                  = lr,\n",
-    "           b1                  = b1,\n",
-    "           b2                  = b2,\n",
-    "           batch_size          = batch_size, \n",
-    "           latent_dim          = latent_dim, \n",
-    "           generator_class     = generator_class, \n",
-    "           discriminator_class = discriminator_class)"
+    "gan = WGANGP( data_shape          = data_shape,\n",
+    "              lr                  = lr,\n",
+    "              b1                  = b1,\n",
+    "              b2                  = b2,\n",
+    "              batch_size          = batch_size, \n",
+    "              latent_dim          = latent_dim, \n",
+    "              generator_class     = generator_class, \n",
+    "              discriminator_class = discriminator_class)"
    ]
   },
   {
@@ -346,7 +367,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/bestModel-v1.ckpt')"
+    "gan = WGANGP.load_from_checkpoint('./run/SHEEP3/models/bestModel.ckpt')"
    ]
   },
   {
@@ -388,9 +409,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Fidle-Lightning",
+   "display_name": "fidle-env",
    "language": "python",
-   "name": "fidle-lightning"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -402,12 +423,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
-  },
-  "vscode": {
-   "interpreter": {
-    "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d"
-   }
+   "version": "3.9.2"
   }
  },
  "nbformat": 4,
diff --git a/DCGAN-PyTorch/modules/GAN.py b/DCGAN-PyTorch/modules/GAN.py
index 4391747..cf5a569 100644
--- a/DCGAN-PyTorch/modules/GAN.py
+++ b/DCGAN-PyTorch/modules/GAN.py
@@ -116,8 +116,6 @@ class GAN(LightningModule):
             
             # These images are reals
             real_labels = torch.ones(batch_size, 1)
-            # Add random noise to the labels
-            # real_labels += 0.05 * torch.rand(batch_size,1)
             real_labels = real_labels.type_as(imgs)
             pred_labels = self.discriminator.forward(imgs)
 
@@ -126,8 +124,6 @@ class GAN(LightningModule):
             # These images are fake
             fake_imgs   = self.generator.forward(z)
             fake_labels = torch.zeros(batch_size, 1)
-            # Add random noise to the labels
-            # fake_labels += 0.05 * torch.rand(batch_size,1)
             fake_labels = fake_labels.type_as(imgs)
 
             fake_loss   = self.adversarial_loss(self.discriminator(fake_imgs.detach()), fake_labels)
diff --git a/DCGAN-PyTorch/modules/WGANGP.py b/DCGAN-PyTorch/modules/WGANGP.py
new file mode 100644
index 0000000..030740b
--- /dev/null
+++ b/DCGAN-PyTorch/modules/WGANGP.py
@@ -0,0 +1,229 @@
+
+# ------------------------------------------------------------------
+#     _____ _     _ _
+#    |  ___(_) __| | | ___
+#    | |_  | |/ _` | |/ _ \
+#    |  _| | | (_| | |  __/
+#    |_|   |_|\__,_|_|\___|                GAN / GAN LigthningModule
+# ------------------------------------------------------------------
+# Formation Introduction au Deep Learning  (FIDLE)
+# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr
+# ------------------------------------------------------------------
+# by JL Parouty (feb 2023) - PyTorch Lightning example
+
+
+import sys
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+from lightning import LightningModule
+
+
+class WGANGP(LightningModule):
+
+    # -------------------------------------------------------------------------
+    # Init
+    # -------------------------------------------------------------------------
+    #
+    def __init__(
+        self,
+        data_shape          = (None,None,None),
+        latent_dim          = None,
+        lr                  = 0.0002,
+        b1                  = 0.5,
+        b2                  = 0.999,
+        batch_size          = 64,
+        lambda_gp           = 10,
+        generator_class     = None,
+        discriminator_class = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        print('\n---- WGANGP initialization -----------------------------------------')
+
+        # ---- Hyperparameters
+        #
+        # Enable Lightning to store all the provided arguments under the self.hparams attribute.
+        # These hyperparameters will also be stored within the model checkpoint.
+        #
+        self.save_hyperparameters()
+
+        print('Hyperarameters are :')
+        for name,value in self.hparams.items():
+            print(f'{name:24s} : {value}')
+
+        # ---- Generator/Discriminator instantiation
+        #
+        # self.generator     = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
+        # self.discriminator = Discriminator(img_shape=data_shape)
+
+        print('Submodels :')
+        module=sys.modules['__main__']
+        class_g = getattr(module, generator_class)
+        class_d = getattr(module, discriminator_class)
+        self.generator     = class_g( latent_dim=latent_dim, data_shape=data_shape)
+        self.discriminator = class_d( latent_dim=latent_dim, data_shape=data_shape)
+
+        # ---- Validation and example data
+        #
+        self.validation_z        = torch.randn(8, self.hparams.latent_dim)
+        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
+
+
+    def forward(self, z):
+        return self.generator(z)
+
+
+    def adversarial_loss(self, y_hat, y):
+        return F.binary_cross_entropy(y_hat, y)
+
+
+
+# ------------------------------------------------------------------------------------ TO DO -------------------
+
+    # see : # from : https://github.com/rosshemsley/gander/blob/main/gander/models/gan.py
+
+    def gradient_penalty(self, real_images, fake_images):
+
+        batch_size = real_images.size(0)
+
+        # ---- Create interpolate images
+        #
+        # Get a random vector : size=([batch_size])
+        epsilon = torch.distributions.uniform.Uniform(0, 1).sample([batch_size])
+        # Add dimensions to match images batch : size=([batch_size,1,1,1])
+        epsilon = epsilon[:, None, None, None]
+        # Put epsilon a the right place
+        epsilon = epsilon.type_as(real_images)
+        # Do interpolation
+        interpolates = epsilon * fake_images + ((1 - epsilon) * real_images)
+
+        # ---- Use autograd to compute gradient
+        #
+        # The key to making this work is including `create_graph`, this means that the computations
+        # in this penalty will be added to the computation graph for the loss function, so that the
+        # second partial derivatives will be correctly computed.
+        #
+        interpolates.requires_grad = True
+
+        pred_labels = self.discriminator.forward(interpolates)
+
+        gradients = torch.autograd.grad(  inputs       = interpolates,
+                                          outputs      = pred_labels, 
+                                          grad_outputs = torch.ones_like(pred_labels),
+                                          create_graph = True, 
+                                          only_inputs  = True )[0]
+
+        grad_flat   = gradients.view(batch_size, -1)
+        grad_norm   = torch.linalg.norm(grad_flat, dim=1)
+
+        grad_penalty = (grad_norm - 1) ** 2 
+
+        return grad_penalty
+
+
+
+# ------------------------------------------------------------------------------------------------------------------
+
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+
+        real_imgs  = batch
+        batch_size = batch.size(0)
+        lambda_gp  = self.hparams.lambda_gp
+
+        # ---- Get some latent space vectors and fake images
+        #      We use type_as() to make sure we initialize z on the right device (GPU/CPU).
+        #
+        z = torch.randn(batch_size, self.hparams.latent_dim)
+        z = z.type_as(real_imgs)
+        
+        fake_imgs = self.generator.forward(z)
+
+        # ---- Train generator
+        #      Generator use optimizer #0
+        #      We try to generate false images that could have nive critics
+        #
+        if optimizer_idx == 0:
+
+            # Get critics
+            critics = self.discriminator.forward(fake_imgs)
+
+            # Loss
+            g_loss = -critics.mean()
+
+            # Log
+            self.log("g_loss", g_loss, prog_bar=True)
+
+            return g_loss
+
+        # ---- Train discriminator
+        #      Discriminator use optimizer #1
+        #      We try to make the difference between fake images and real ones 
+        #
+        if optimizer_idx == 1:
+            
+            # Get critics
+            critics_real = self.discriminator.forward(real_imgs)
+            critics_fake = self.discriminator.forward(fake_imgs)
+
+            # Get gradient penalty
+            grad_penalty = self.gradient_penalty(real_imgs, fake_imgs)
+
+            # Loss
+            d_loss = critics_fake.mean() - critics_real.mean() + lambda_gp*grad_penalty.mean()
+
+            # Log loss
+            self.log("d_loss", d_loss, prog_bar=True)
+
+            return d_loss
+
+
+    def configure_optimizers(self):
+
+        lr = self.hparams.lr
+        b1 = self.hparams.b1
+        b2 = self.hparams.b2
+
+        # With a GAN, we need 2 separate optimizer.
+        # opt_g to optimize the generator      #0
+        # opt_d to optimize the discriminator  #1
+        # opt_g = torch.optim.Adam(self.generator.parameters(),     lr=lr, betas=(b1, b2))
+        # opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2),)
+        opt_g = torch.optim.Adam(self.generator.parameters(),     lr=lr)
+        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
+        return [opt_g, opt_d], []
+
+
+    def training_epoch_end(self, outputs):
+
+        # Get our validation latent vectors as z
+        # z = self.validation_z.type_as(self.generator.model[0].weight)
+
+        # ---- Log Graph
+        #
+        if(self.current_epoch==1):
+            sampleImg=torch.rand((1,28,28,1))
+            sampleImg=sampleImg.type_as(self.generator.model[0].weight)
+            self.logger.experiment.add_graph(self.discriminator,sampleImg)
+
+        # ---- Log d_loss/epoch
+        #
+        g_loss, d_loss = 0,0
+        for metrics in outputs:
+            g_loss+=float( metrics[0]['loss'] )
+            d_loss+=float( metrics[1]['loss'] )
+        g_loss, d_loss = g_loss/len(outputs), d_loss/len(outputs)
+        self.logger.experiment.add_scalar("g_loss/epochs",g_loss, self.current_epoch)
+        self.logger.experiment.add_scalar("d_loss/epochs",d_loss, self.current_epoch)
+
+        # ---- Log some of these images
+        #
+        z = torch.randn(self.hparams.batch_size, self.hparams.latent_dim)
+        z = z.type_as(self.generator.model[0].weight)
+        sample_imgs = self.generator(z)
+        sample_imgs = sample_imgs.permute(0, 3, 1, 2) # from NHWC to NCHW
+        grid = torchvision.utils.make_grid(tensor=sample_imgs, nrow=12, )
+        self.logger.experiment.add_image(f"Generated images", grid,self.current_epoch)
-- 
GitLab