# ------------------------------------------------------------------
# _____ _ _ _
# | ___(_) __| | | ___
# | |_ | |/ _` | |/ _ \
# | _| | | (_| | | __/
# |_| |_|\__,_|_|\___| GAN / Generators
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning (FIDLE)
# CNRS/MIAI - https://fidle.cnrs.fr
# ------------------------------------------------------------------
# JL Parouty (March 2024)
import numpy as np
import torch.nn as nn
# -----------------------------------------------------------------------------
# -- Generator n°1
# -----------------------------------------------------------------------------
#
class Generator_1(nn.Module):

    def __init__(self, latent_dim=None, data_shape=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape  = data_shape
        print('init generator 1 : ', latent_dim, ' to ', data_shape)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, momentum=0.8),   # momentum passed by keyword (2nd positional arg is eps)
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, momentum=0.8),
            nn.ReLU(),
            nn.Linear(1024, int(np.prod(data_shape))),
            nn.Sigmoid()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img
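
# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the FIDLE training code): Generator_1
# is a plain MLP mapping a (batch, latent_dim) tensor to (batch, *data_shape).
# The values below (latent_dim=100, MNIST-like data_shape=(28, 28, 1)) are
# assumptions, not values fixed by this file.
#
#   import torch
#   g1 = Generator_1(latent_dim=100, data_shape=(28, 28, 1))
#   g1.eval()                     # eval mode: BatchNorm1d uses running statistics
#   z  = torch.randn(16, 100)     # 16 random latent vectors
#   with torch.no_grad():
#       imgs = g1(z)              # imgs.shape == torch.Size([16, 28, 28, 1])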
# -----------------------------------------------------------------------------
# -- Generator n°2
# -----------------------------------------------------------------------------
#
class Generator_2(nn.Module):

    def __init__(self, latent_dim=None, data_shape=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape  = data_shape
        print('init generator 2 : ', latent_dim, ' to ', data_shape)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 7*7*64),
            nn.Unflatten(1, (64, 7, 7)),
            # nn.UpsamplingNearest2d( scale_factor=2 ),
            nn.UpsamplingBilinear2d( scale_factor=2 ),
            nn.Conv2d( 64, 128, (3,3), stride=(1,1), padding=(1,1) ),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # nn.UpsamplingNearest2d( scale_factor=2 ),
            nn.UpsamplingBilinear2d( scale_factor=2 ),
            nn.Conv2d( 128, 256, (3,3), stride=(1,1), padding=(1,1) ),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d( 256, 1, (5,5), stride=(1,1), padding=(2,2) ),
            nn.Sigmoid()
        )

    def forward(self, z):
        img_nchw = self.model(z)
        img_nhwc = img_nchw.permute(0, 2, 3, 1)   # reformat from NCHW to NHWC
        return img_nhwc
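
# -----------------------------------------------------------------------------
# Quick shape check (illustrative only, assumed values): Generator_2 upsamples
# a 7x7 feature map twice (7 -> 14 -> 28), so it produces 28x28x1 images in
# NHWC layout. latent_dim=100 and data_shape=(28, 28, 1) below are example
# values, not ones prescribed by this file.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    g2 = Generator_2(latent_dim=100, data_shape=(28, 28, 1))
    g2.eval()                                 # BatchNorm2d uses running statistics
    z = torch.randn(8, 100)                   # 8 random latent vectors
    with torch.no_grad():
        imgs = g2(z)
    print('Generator_2 output shape :', imgs.shape)   # torch.Size([8, 28, 28, 1])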