Skip to content
Snippets Groups Projects
Generators.py 2.95 KiB
Newer Older

# ------------------------------------------------------------------
#     _____ _     _ _
#    |  ___(_) __| | | ___
#    | |_  | |/ _` | |/ _ \
#    |  _| | | (_| | |  __/
#    |_|   |_|\__,_|_|\___|                         GAN / Generators
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning  (FIDLE)
# CNRS/MIAI - https://fidle.cnrs.fr
# ------------------------------------------------------------------
# JL Parouty (Mars 2024)



import numpy as np
import torch.nn as nn


# -----------------------------------------------------------------------------
# -- Generator n°1
# -----------------------------------------------------------------------------
#
class Generator_1(nn.Module):
    """Fully-connected generator mapping a latent vector to a sample of ``data_shape``.

    The latent vector goes through Linear + ReLU stages of 128, 256, 512 and
    1024 units (the last three with batch normalisation). A final Linear layer
    sized to ``prod(data_shape)`` is followed by a Sigmoid, so every output
    value lies in [0, 1]. The flat output is then reshaped to ``data_shape``.

    Args:
        latent_dim: Size of the input latent vector (``in_features`` of the
            first Linear layer). Required in practice despite the ``None``
            default.
        data_shape: Shape of one generated sample, excluding the batch
            dimension. ``forward`` returns ``(batch, *data_shape)``. Required
            in practice despite the ``None`` default.
        bn_eps: ``eps`` of the BatchNorm1d layers. The default 0.8 reproduces
            the value historically passed positionally to ``nn.BatchNorm1d``.
    """

    def __init__(self, latent_dim=None, data_shape=None, bn_eps=0.8):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape  = data_shape
        print('init generator 1         : ',latent_dim,' to ',data_shape)

        # NOTE(review): the second positional argument of nn.BatchNorm1d is
        # `eps`, not `momentum`, so the former `BatchNorm1d(n, 0.8)` set a very
        # large epsilon (it looks like a port of Keras'
        # BatchNormalization(momentum=0.8)). The value is kept as the default
        # so results are unchanged, but it is now an explicit keyword.
        # The layer order is unchanged, so state_dict keys (model.<i>.*) stay
        # compatible with existing checkpoints.
        self.model = nn.Sequential(

            nn.Linear(latent_dim, 128),
            nn.ReLU(),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256, eps=bn_eps),
            nn.ReLU(),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512, eps=bn_eps),
            nn.ReLU(),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, eps=bn_eps),
            nn.ReLU(),

            # One output unit per element of a sample; Sigmoid squashes to [0, 1].
            nn.Linear(1024, int(np.prod(data_shape))),
            nn.Sigmoid()

        )

    def forward(self, z):
        """Generate a batch of samples from latent vectors.

        Args:
            z: Latent batch, expected as (batch, latent_dim).

        Returns:
            Tensor of shape (batch, *data_shape) with values in [0, 1].
        """
        img = self.model(z)
        # Restore the per-sample shape; dim 0 stays the batch dimension.
        img = img.view(img.size(0), *self.img_shape)
        return img


# -----------------------------------------------------------------------------
# -- Generator n°2
# -----------------------------------------------------------------------------
#
class Generator_2(nn.Module):
    """Convolutional generator producing 28x28 single-channel images in NHWC layout.

    The latent vector is projected to a 64x7x7 feature map. Two stages of
    (x2 bilinear upsample -> 3x3 conv -> ReLU -> BatchNorm2d) bring it to
    256x28x28. A 5x5 convolution to one channel and a Sigmoid give values in
    [0, 1].

    Args:
        latent_dim: Size of the input latent vector. Required in practice
            despite the ``None`` default.
        data_shape: Stored as ``img_shape`` and printed, but NOT used to size
            the network: the output is always (batch, 28, 28, 1).
    """

    def __init__(self, latent_dim=None, data_shape=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape  = data_shape
        print('init generator 2         : ',latent_dim,' to ',data_shape)

        def upsample_stage(channels_in, channels_out):
            # Doubles H and W. The 3x3 conv with padding 1 keeps the spatial
            # size. Normalisation comes after the activation in this design.
            # nn.Upsample(mode='bilinear', align_corners=True) is exactly what
            # the deprecated nn.UpsamplingBilinear2d wraps. Nearest-neighbour
            # upsampling is a possible alternative.
            return [
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(channels_out),
            ]

        # Module order (and hence state_dict indices model.0 ... model.11) is
        # fixed: projection, reshape, two upsample stages, output conv, Sigmoid.
        stages = [
            nn.Linear(latent_dim, 7 * 7 * 64),
            nn.Unflatten(1, (64, 7, 7)),            # (N, 3136) -> (N, 64, 7, 7)
            *upsample_stage(64, 128),               # -> (N, 128, 14, 14)
            *upsample_stage(128, 256),              # -> (N, 256, 28, 28)
            nn.Conv2d(256, 1, kernel_size=5, stride=1, padding=2),  # -> (N, 1, 28, 28)
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images from latent vectors.

        Args:
            z: Latent batch, expected as (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 28, 28, 1), channels last, values in [0, 1].
            It is a permuted (non-contiguous) view of the network output.
        """
        # The conv stack emits NCHW; move the channel axis last to return NHWC.
        return self.model(z).permute(0, 2, 3, 1)