
# ------------------------------------------------------------------
#     _____ _     _ _
#    |  ___(_) __| | | ___
#    | |_  | |/ _` | |/ _ \
#    |  _| | | (_| | |  __/
#    |_|   |_|\__,_|_|\___|                   WGANGP LightningModule
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning  (FIDLE)
# CNRS/MIAI - https://fidle.cnrs.fr
# ------------------------------------------------------------------
# JL Parouty (March 2024)



import sys
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from lightning import LightningModule


class WGANGP(LightningModule):

    # -------------------------------------------------------------------------
    # Init
    # -------------------------------------------------------------------------
    #
    def __init__(
        self,
        data_shape          = (None,None,None),
        latent_dim          = None,
        lr                  = 0.0002,
        b1                  = 0.5,
        b2                  = 0.999,
        batch_size          = 64,
        lambda_gp           = 10,
        generator_name      = None,
        discriminator_name  = None,
        **kwargs,
    ):
        super().__init__()

        print('\n---- GAN initialization --------------------------------------------')

        # ---- Hyperparameters
        #
        # Enable Lightning to store all the provided arguments under the self.hparams attribute.
        # These hyperparameters will also be stored within the model checkpoint.
        #
        self.save_hyperparameters()

        print('Hyperparameters are:')
        for name,value in self.hparams.items():
            print(f'{name:24s} : {value}')

        # ---- Manual optimization
        #
        # Because we have more than one optimizer, Lightning's automatic optimization
        # is disabled; training_step() calls manual_backward(), step() and zero_grad() itself.
        #
        self.automatic_optimization = False

        # ---- Generator/Discriminator instantiation
        #
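        # The generator and discriminator classes are resolved by name in the __main__
        # module, i.e. they must be defined in the calling script or notebook.
        #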
        print('Submodels:')
        module=sys.modules['__main__']
        class_g = getattr(module, generator_name)
        class_d = getattr(module, discriminator_name)
        self.generator     = class_g( latent_dim=latent_dim, data_shape=data_shape)
        self.discriminator = class_d( latent_dim=latent_dim, data_shape=data_shape)

        # ---- Validation and example data
        #
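        # validation_z        : fixed latent vectors, handy for monitoring generated samples
        # example_input_array : lets Lightning trace the model (summary / graph logging)
        #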
        self.validation_z        = torch.randn(8, self.hparams.latent_dim)
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)


    def forward(self, z):
        return self.generator(z)


    # Standard GAN adversarial loss (BCE); not used by the WGAN-GP training loop below
    def adversarial_loss(self, y_pred, y):
        return F.binary_cross_entropy(y_pred, y)



    def gradient_penalty(self, real_images, fake_images):

        # see: https://medium.com/dejunhuang/implementing-gan-and-wgan-in-pytorch-551099afde3c
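        #
        # WGAN-GP gradient penalty (Gulrajani et al., 2017):
        #   GP = E_xhat[ ( || grad_xhat D(xhat) ||_2 - 1 )^2 ]
        #   with xhat = eps*fake + (1-eps)*real, eps ~ U(0,1)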

        batch_size = real_images.size(0)

        # ---- Create interpolate images
        #
        # Get a random vector : size=([batch_size])
        epsilon = torch.distributions.uniform.Uniform(0, 1).sample([batch_size])
        
        # Add dimensions to match images batch : size=([batch_size,1,1,1])
        epsilon = epsilon[:, None, None, None]
        
        # Move epsilon to the same device/dtype as the images
        epsilon = epsilon.type_as(real_images)
        
        # Do interpolation
        interpolates = epsilon * fake_images + ((1 - epsilon) * real_images)

        # ---- Use autograd to compute gradient
        #
        # The key to making this work is `create_graph=True`: the computations in this
        # penalty are added to the computation graph of the loss, so that the second-order
        # partial derivatives are computed correctly during the backward pass.
        #
        interpolates.requires_grad_()

        pred_labels = self.discriminator.forward(interpolates)

        gradients = torch.autograd.grad(  inputs       = interpolates,
                                          outputs      = pred_labels, 
                                          grad_outputs = torch.ones_like(pred_labels),
                                          create_graph = True, 
                                          retain_graph = True,
                                          only_inputs  = True )[0]

        # Per-sample gradient norm over all pixels/channels
        grad_flat   = gradients.view(batch_size, -1)
        grad_norm   = torch.linalg.norm(grad_flat, dim=1)

        # Per-sample penalty term (mean and lambda_gp factor are applied in training_step)
        grad_penalty = (grad_norm - 1) ** 2

        # gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()

        return grad_penalty



    def training_step(self, batch, batch_idx):
        real_imgs  = batch
        batch_size = batch.size(0)
        lambda_gp  = self.hparams.lambda_gp


        optimizer_g, optimizer_d = self.optimizers()

        # ---- Get some latent space vectors
        #      We use type_as() to make sure we initialize z on the right device (GPU/CPU).
        #
        z = torch.randn(batch_size, self.hparams.latent_dim)
        z = z.type_as(real_imgs)
        
        # ---- Train generator ------------------------------------------------
        #      The generator uses optimizer #0.
        #      We try to generate fake images that mislead the discriminator (critic).
        # ---------------------------------------------------------------------
        #
        self.toggle_optimizer(optimizer_g)
                
        # Get fake images
        fake_imgs = self.generator.forward(z)
        
        # Get critic scores on the fake images
        critics   = self.discriminator.forward(fake_imgs)

        # Wasserstein generator loss: minimize -E[ D(G(z)) ], i.e. maximize the critic score on fakes
        g_loss = -critics.mean()

        # Log
        self.log("g_loss", g_loss, prog_bar=True)

        # Backward loss
        self.manual_backward(g_loss)
        
        optimizer_g.step()
        optimizer_g.zero_grad()
        
        self.untoggle_optimizer(optimizer_g)

        # ---- Train discriminator --------------------------------------------
        #      The discriminator (critic) uses optimizer #1.
        #      We try to tell fake images apart from real ones.
        # ---------------------------------------------------------------------
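        #      Note: the reference WGAN-GP algorithm usually performs several critic
        #      updates per generator update (n_critic); here both are updated once per batch.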
        #
        self.toggle_optimizer(optimizer_d)

        # Get critic scores on real and fake images
        critics_real = self.discriminator.forward(real_imgs)
        critics_fake = self.discriminator.forward(fake_imgs.detach())

        # Get gradient penalty
        grad_penalty = self.gradient_penalty(real_imgs, fake_imgs.detach())

        # Wasserstein critic loss with gradient penalty: E[D(fake)] - E[D(real)] + lambda_gp * GP
        d_loss = critics_fake.mean() - critics_real.mean() + lambda_gp*grad_penalty.mean()

        # Log loss
        self.log("d_loss", d_loss, prog_bar=True)

        # Backward
        self.manual_backward(d_loss)
        
        optimizer_d.step()
        optimizer_d.zero_grad()
        
        self.untoggle_optimizer(optimizer_d)
 


    def configure_optimizers(self):

        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        # With a GAN, we need two separate optimizers.
        # Note: the b1/b2 hyperparameters are only used by the commented-out variants below;
        # the active optimizers keep Adam's default betas.
        # opt_g = torch.optim.Adam(self.generator.parameters(),     lr=lr, betas=(b1, b2))
        # opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        opt_g = torch.optim.Adam(self.generator.parameters(),     lr=lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)

        # Returned as (optimizers list, LR schedulers list); no schedulers here
        return [opt_g, opt_d], []


    def on_train_epoch_end(self):

        # ---- Log Graph
        #
        if self.current_epoch == 1:
            # Hard-coded NHWC sample; its shape must match the data_shape
            # expected by the discriminator ((28,28,1) here)
            sampleImg = torch.rand((1, 28, 28, 1))
            sampleImg = sampleImg.type_as(self.generator.model[0].weight)
            self.logger.experiment.add_graph(self.discriminator, sampleImg)

        # ---- Log a sample of generated images
        #
        z = torch.randn(self.hparams.batch_size, self.hparams.latent_dim)
        z = z.type_as(self.generator.model[0].weight)
        sample_imgs = self.generator(z)
        sample_imgs = sample_imgs.permute(0, 3, 1, 2) # from NHWC to NCHW
        grid = torchvision.utils.make_grid(tensor=sample_imgs, nrow=12)
        self.logger.experiment.add_image("Generated images", grid, self.current_epoch)
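

# ------------------------------------------------------------------
# Typical use (sketch only):
# The class names and values below are placeholders; the real generator and
# discriminator classes must be defined in the calling notebook (__main__),
# as explained in __init__ above.
#
#   gan     = WGANGP( data_shape=(28,28,1), latent_dim=128, lr=1e-4,
#                     batch_size=64, lambda_gp=10,
#                     generator_name='Generator', discriminator_name='Discriminator' )
#   trainer = lightning.Trainer(max_epochs=10)
#   trainer.fit(gan, train_dataloaders=train_loader)
# ------------------------------------------------------------------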