Skip to content
Snippets Groups Projects
Discriminators.py 4.27 KiB
Newer Older
# ------------------------------------------------------------------
#     _____ _     _ _
#    |  ___(_) __| | | ___
#    | |_  | |/ _` | |/ _ \
#    |  _| | | (_| | |  __/
#    |_|   |_|\__,_|_|\___|                         GAN / Discriminators
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning  (FIDLE)
# CNRS/MIAI - https://fidle.cnrs.fr
# ------------------------------------------------------------------
# JL Parouty (Mars 2024)


import numpy as np
import torch.nn as nn


# -----------------------------------------------------------------------------
# -- Discriminator n°1
# -----------------------------------------------------------------------------
#
class Discriminator_1(nn.Module):
    '''
    Fully-connected (DNN) discriminator, usable with a classic GAN.

    Each input sample is flattened and passed through two hidden ReLU layers
    (512 then 256 units). It is then reduced to a single value squashed into
    (0, 1) by a sigmoid.

    Args:
        latent_dim: not used by this discriminator; accepted so that every
            discriminator in this file shares the same constructor signature.
        data_shape: shape of one input sample; its element count sets the
            width of the first Linear layer.
    '''

    def __init__(self, latent_dim=None, data_shape=None):
        super().__init__()
        self.img_shape = data_shape
        print('init discriminator 1     : ',data_shape,' to sigmoid')

        n_inputs = int(np.prod(data_shape))

        # Feature extractor: flatten, then 512 -> 256 hidden units.
        body = [
            nn.Flatten(),
            nn.Linear(n_inputs, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
        ]
        # Output head: one score per sample, mapped into (0, 1).
        head = [nn.Linear(256, 1), nn.Sigmoid()]

        self.model = nn.Sequential(*body, *head)

    def forward(self, img):
        '''Return a (batch, 1) tensor of sigmoid scores for the batch `img`.'''
        return self.model(img)



# -----------------------------------------------------------------------------
# -- Discriminator n°2
# -----------------------------------------------------------------------------
#
class Discriminator_2(nn.Module):
    '''
    A more efficient discriminator, based on a CNN, usable with a classic GAN.

    Images are given channels-last (NHWC) and permuted to NCHW in forward().
    The network has four conv stages (kernel 3, padding 1, strides 2, 1, 1, 2),
    each followed by ReLU, BatchNorm2d and Dropout2d. A single Linear unit and
    a sigmoid then produce one probability per sample.

    The number of input channels and the width of the final Linear layer are
    derived from data_shape. With data_shape=None, the sizes for 28x28
    single-channel images are used (12544 = 256 * 7 * 7 flattened features).

    Args:
        latent_dim: not used here; accepted so that every discriminator in
            this file shares the same constructor signature.
        data_shape: shape of one image as (height, width, channels). It is
            assumed channels-last, to match the NHWC permute in forward().
    '''

    # (out_channels, stride) of each conv stage; every conv uses kernel 3, padding 1.
    _STAGES = ((32, 2), (64, 1), (128, 1), (256, 2))

    # Image shape (height, width, channels) assumed when data_shape is None.
    _DEFAULT_SHAPE = (28, 28, 1)

    def __init__(self, latent_dim=None, data_shape=None):

        super().__init__()
        self.img_shape = data_shape
        print('init discriminator 2     : ',data_shape,' to sigmoid')

        shape = self._DEFAULT_SHAPE if data_shape is None else data_shape
        height, width, channels = (int(dim) for dim in shape)

        layers = []
        for out_channels, stride in self._STAGES:
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(0.25),
            ]
            # Track the feature-map size so the Linear head fits any image size.
            channels = out_channels
            height = self._conv_out_size(height, stride)
            width = self._conv_out_size(width, stride)

        layers += [
            nn.Flatten(),
            nn.Linear(channels * height * width, 1),
            nn.Sigmoid(),
        ]
        # One flat Sequential: parameters are named model.<index>.*, with the
        # convs at 0-15, Flatten at 16, Linear at 17 and Sigmoid at 18.
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _conv_out_size(size, stride):
        '''Return the spatial output size of a kernel-3, padding-1 convolution.'''
        return (size + 2 * 1 - 3) // stride + 1

    def forward(self, img):
        '''Return a (batch, 1) tensor of probabilities for the NHWC batch `img`.'''
        img_nchw = img.permute(0, 3, 1, 2) # reformat from NHWC to NCHW
        validity = self.model(img_nchw)

        return validity


   
# -----------------------------------------------------------------------------
# -- Discriminator n°3
# -----------------------------------------------------------------------------
#     
class Discriminator_3(nn.Module):
    '''
    A CNN discriminator (critic), usable with a WGAN-GP.

    This discriminator has no final sigmoid. Its last Linear unit returns an
    unbounded critic score per sample, not a probability.

    Images are given channels-last (NHWC) and permuted to NCHW in forward().
    The network has four conv stages (kernel 3, padding 1, strides 2, 1, 1, 2),
    each followed by ReLU, BatchNorm2d and Dropout2d, then a single Linear
    unit. The number of input channels and the Linear width are derived from
    data_shape. With data_shape=None, the sizes for 28x28 single-channel
    images are used (12544 = 256 * 7 * 7 flattened features).

    NOTE(review): the WGAN-GP paper advises against batch normalization in
    the critic, because the gradient penalty is computed per sample; it
    suggests layer normalization instead. Consider replacing BatchNorm2d,
    but confirm first, since that changes the parameter layout.

    Args:
        latent_dim: not used here; accepted so that every discriminator in
            this file shares the same constructor signature.
        data_shape: shape of one image as (height, width, channels). It is
            assumed channels-last, to match the NHWC permute in forward().
    '''

    # (out_channels, stride) of each conv stage; every conv uses kernel 3, padding 1.
    _STAGES = ((32, 2), (64, 1), (128, 1), (256, 2))

    # Image shape (height, width, channels) assumed when data_shape is None.
    _DEFAULT_SHAPE = (28, 28, 1)

    def __init__(self, latent_dim=None, data_shape=None):

        super().__init__()
        self.img_shape = data_shape
        print('init discriminator 3     : ',data_shape,' to critic')

        shape = self._DEFAULT_SHAPE if data_shape is None else data_shape
        height, width, channels = (int(dim) for dim in shape)

        layers = []
        for out_channels, stride in self._STAGES:
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(0.25),
            ]
            # Track the feature-map size so the Linear head fits any image size.
            channels = out_channels
            height = self._conv_out_size(height, stride)
            width = self._conv_out_size(width, stride)

        layers += [
            nn.Flatten(),
            # Raw critic score: deliberately no sigmoid, so the output is unbounded.
            nn.Linear(channels * height * width, 1),
        ]
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _conv_out_size(size, stride):
        '''Return the spatial output size of a kernel-3, padding-1 convolution.'''
        return (size + 2 * 1 - 3) // stride + 1

    def forward(self, img):
        '''Return a (batch, 1) tensor of unbounded critic scores for the NHWC batch `img`.'''
        img_nchw = img.permute(0, 3, 1, 2) # reformat from NHWC to NCHW
        validity = self.model(img_nchw)

        return validity