# ------------------------------------------------------------------
# _____ _ _ _
# | ___(_) __| | | ___
# | |_ | |/ _` | |/ _ \
# | _| | | (_| | | __/
# |_| |_|\__,_|_|\___| VAE Example
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning (FIDLE)
# CNRS/MIAI - https://fidle.cnrs.fr
# ------------------------------------------------------------------
# JL Parouty (March 2024)
import os

import numpy as np
import torch
import keras

from IPython.display import display, Markdown
from modules.layers import SamplingLayer

# Note : https://keras.io/guides/making_new_layers_and_models_via_subclassing/
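
# The SamplingLayer imported above implements the reparameterization trick :
# z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon ~ N(0, I).
# A minimal sketch of such a layer (the actual implementation lives in
# modules/layers.py) could look like :
#
#   class SamplingLayer(keras.layers.Layer):
#       def call(self, inputs):
#           z_mean, z_log_var = inputs
#           epsilon = keras.random.normal(shape=keras.ops.shape(z_mean))
#           return z_mean + keras.ops.exp(0.5 * z_log_var) * epsilon
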
class VAE(keras.Model):
    '''
    A VAE model, built from a given encoder and a given decoder.
    '''

    version = '2.0'

    def __init__(self, encoder=None, decoder=None, loss_weights=[1, 1], **kwargs):
        '''
        VAE instantiation with encoder and decoder models.
        args :
            encoder      : Encoder model
            decoder      : Decoder model
            loss_weights : Weights of the two loss terms : [reconstruction_loss, kl_loss]
        return:
            None
        '''
        super().__init__(**kwargs)
        self.encoder      = encoder
        self.decoder      = decoder
        self.loss_weights = loss_weights
        print(f'Fidle VAE is ready :-) loss_weights={list(self.loss_weights)}')
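
    # Typical usage (a minimal sketch ; `encoder`, `decoder` and `x_train` are
    # assumed to exist, with the encoder returning (z_mean, z_log_var, z)) :
    #
    #   vae = VAE(encoder, decoder, loss_weights=[1, 0.001])
    #   vae.compile(optimizer='adam')
    #   vae.fit(x_train, epochs=10, batch_size=64)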

    def call(self, inputs):
        '''
        Model forward pass, used when the model is called on data.
        args:
            inputs : Model inputs
        return:
            output : Output of the model (the reconstruction)
        '''
        z_mean, z_log_var, z = self.encoder(inputs)
        output = self.decoder(z)
        return output

    def train_step(self, input):
        '''
        Implementation of one training step.
        Receives a batch, computes the loss, backpropagates the gradients,
        updates the weights and returns the metrics.
        args:
            input : Model inputs (a batch)
        return:
            Dict mapping the metric names to their current value (this
            includes the total loss, tracked in self.metrics).
        '''
        # ---- Get the input we need, as specified in the .fit() call
        #
        if isinstance(input, tuple):
            input = input[0]

        k1, k2 = self.loss_weights

        # ---- Reset gradients
        #
        self.zero_grad()

        # ---- Forward pass : encoder, then decoder
        #
        z_mean, z_log_var, z = self.encoder(input)
        reconstruction = self.decoder(z)

        # ---- Compute loss
        #      Total loss = k1 * reconstruction loss + k2 * KL loss
        #
        r_loss  = torch.nn.functional.binary_cross_entropy(reconstruction, input, reduction='sum')
        kl_loss = - torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
        loss    = r_loss*k1 + kl_loss*k2
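
        # For reference : the kl_loss term above is (twice) the closed-form KL
        # divergence between the approximate posterior N(z_mean, exp(z_log_var))
        # and the prior N(0, I) :
        #
        #   KL = -1/2 * sum( 1 + log(var) - mean^2 - var )
        #
        # the conventional 1/2 factor being folded into the k2 weight here.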

        # ---- Backward pass : compute the gradients of the loss
        #
        loss.backward()

        # ---- Adjust the trainable weights (Keras 3, torch backend)
        #
        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # ---- Update metrics (includes the metric that tracks the loss)
        #
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(input, reconstruction)

        # ---- Return a dict mapping metric names to their current value
        #      (note that it includes the loss, tracked in self.metrics)
        #
        return {m.name: m.result() for m in self.metrics}
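
    # Note : the metrics updated above are those passed to compile().
    # A hypothetical example, tracking the reconstruction quality :
    #
    #   vae.compile(optimizer='adam', metrics=[keras.metrics.MeanSquaredError(name='mse')])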

    # ---- Legacy TensorFlow / GradientTape implementation of the same
    #      training step, kept for reference :
    #
    #   # ---- Forward pass
    #   #      Run the forward pass and record operations on the GradientTape
    #   #
    #   with tf.GradientTape() as tape:
    #
    #       # ---- Get encoder outputs
    #       #
    #       z_mean, z_log_var, z = self.encoder(input)
    #
    #       # ---- Get reconstruction from decoder
    #       #
    #       reconstruction = self.decoder(z)
    #
    #       # ---- Compute loss : reconstruction loss, KL loss and total loss
    #       #
    #       reconstruction_loss = k1 * tf.reduce_mean( keras.losses.binary_crossentropy(input, reconstruction) )
    #       kl_loss             = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    #       kl_loss             = -tf.reduce_mean(kl_loss) * k2
    #       total_loss          = reconstruction_loss + kl_loss
    #
    #   # ---- Retrieve gradients from the gradient tape and run one step
    #   #      of gradient descent to optimize the trainable weights
    #   #
    #   grads = tape.gradient(total_loss, self.trainable_weights)
    #   self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    #
    #   return {
    #       "loss":    total_loss,
    #       "r_loss":  reconstruction_loss,
    #       "kl_loss": kl_loss,
    #   }

    def predict(self, inputs):
        '''Encode the inputs, then decode the sampled z to get the outputs.
        Note : this overrides keras.Model.predict() with a simple two-stage pass.'''
        z_mean, z_log_var, z = self.encoder.predict(inputs)
        outputs = self.decoder.predict(z)
        return outputs

    def save(self, filename):
        '''Save the model in two parts : encoder and decoder.'''
        filename, extension = os.path.splitext(filename)
        self.encoder.save(f'{filename}-encoder.keras')
        self.decoder.save(f'{filename}-decoder.keras')

    def reload(self, filename):
        '''Reload a model saved in two parts.'''
        filename, extension = os.path.splitext(filename)
        self.encoder = keras.models.load_model(f'{filename}-encoder.keras', custom_objects={'SamplingLayer': SamplingLayer})
        self.decoder = keras.models.load_model(f'{filename}-decoder.keras')
        print('Reloaded.')
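
    # Example (hypothetical path) : saving a trained model, then restoring it
    # into a fresh, empty VAE shell :
    #
    #   vae.save('./run/models/vae.keras')     # writes vae-encoder.keras and vae-decoder.keras
    #   vae2 = VAE()
    #   vae2.reload('./run/models/vae.keras')  # restores both parts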

    @classmethod
    def about(cls):
        '''Basic whoami method'''
        display(Markdown('<br>**FIDLE 2024 - VAE**'))
        print('Version       :', cls.version)
        print('Keras version :', keras.__version__)