Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# ------------------------------------------------------------------
#     _____ _     _ _
#    |  ___(_) __| | | ___
#    | |_  | |/ _` | |/ _ \
#    |  _| | | (_| | |  __/
#    |_|   |_|\__,_|_|\___|          GAN / Generators
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning (FIDLE)
# CNRS/MIAI - https://fidle.cnrs.fr
# ------------------------------------------------------------------
# JL Parouty (Mars 2024)
import numpy as np
import torch.nn as nn
# -----------------------------------------------------------------------------
# -- Generator n°1
# -----------------------------------------------------------------------------
#
class Generator_1(nn.Module):
    """Fully-connected (MLP) generator.

    Maps a batch of latent vectors to images of shape ``data_shape``.
    The hidden layers widen 128 -> 256 -> 512 -> 1024. A final Linear
    projection produces ``prod(data_shape)`` values, which a Sigmoid
    squashes into [0, 1].

    Args:
        latent_dim: size of each input latent vector. Must be supplied; the
            ``None`` default cannot be used to build the first Linear layer.
        data_shape: per-sample output shape (e.g. a tuple). Must be supplied.
    """

    def __init__(self, latent_dim=None, data_shape=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = data_shape
        print('init generator 1 : ',latent_dim,' to ',data_shape)

        widths = [128, 256, 512, 1024]

        # Input layer: Linear + ReLU, without batch normalisation.
        stack = [nn.Linear(latent_dim, widths[0]), nn.ReLU()]

        # Hidden blocks: Linear -> BatchNorm1d -> ReLU, one per width step.
        for width_in, width_out in zip(widths, widths[1:]):
            # NOTE(review): 0.8 is BatchNorm1d's *eps* (its second positional
            # parameter), not its momentum. This may be a port of a Keras
            # BatchNormalization(momentum=0.8); behaviour kept as-is -- confirm intent.
            stack.extend((
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out, eps=0.8),
                nn.ReLU(),
            ))

        # Output layer: one value per output element, mapped into [0, 1].
        out_features = int(np.prod(data_shape))
        stack.extend((nn.Linear(widths[-1], out_features), nn.Sigmoid()))

        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Generate a batch of images from latent codes ``z``.

        Args:
            z: tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, *data_shape) with values in [0, 1].
        """
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)
# -----------------------------------------------------------------------------
# -- Generator n°2
# -----------------------------------------------------------------------------
#
class Generator_2(nn.Module):
    """Convolutional generator producing 28x28 single-channel images.

    Pipeline:
    1. A Linear layer expands the latent vector to a 64x7x7 feature map.
    2. Two rounds of (bilinear x2 upsample -> 3x3 conv -> ReLU -> BatchNorm2d)
       bring it to 256x28x28.
    3. A 5x5 conv reduces this to 1 channel, and a Sigmoid maps the values
       into [0, 1].

    The result is returned channels-last: (batch, 28, 28, 1).

    Args:
        latent_dim: size of each input latent vector. Must be supplied; the
            ``None`` default cannot be used to build the first Linear layer.
        data_shape: stored as ``self.img_shape`` but NOT used to size the
            network. The output is always (batch, 28, 28, 1).
            NOTE(review): callers presumably pass (28, 28, 1) -- confirm.
    """

    def __init__(self, latent_dim=None, data_shape=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = data_shape
        print('init generator 2 : ',latent_dim,' to ',data_shape)
        # nn.UpsamplingBilinear2d is deprecated. nn.Upsample with
        # mode='bilinear' and align_corners=True is its documented equivalent,
        # so the outputs are unchanged. Module order (and thus state_dict
        # indices) is also unchanged.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 7*7*64),
            nn.Unflatten(1, (64, 7, 7)),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 7 -> 14
            nn.Conv2d(64, 128, (3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 14 -> 28
            nn.Conv2d(128, 256, (3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 1, (5, 5), stride=(1, 1), padding=(2, 2)),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Generate a batch of images from latent codes ``z``.

        Args:
            z: tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 28, 28, 1) with values in [0, 1].
        """
        img_nchw = self.model(z)
        # The conv stack works in NCHW; reorder the axes to NHWC for the caller.
        return img_nchw.permute(0, 2, 3, 1)