import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# See: https://keras.io/api/models/model/
# See: https://keras.io/guides/customizing_what_happens_in_fit/


class AE(keras.Model):
    """Keras model that chains an encoder and a decoder.

    The forward pass feeds the inputs to ``encoder`` and the result to
    ``decoder``. (The name presumably stands for "autoencoder".)

    Args:
        encoder: Callable sub-model applied to the inputs. It must also
            provide ``save`` for :meth:`save` to work.
        decoder: Callable sub-model applied to the encoder output. Same
            requirement as ``encoder`` regarding ``save``.
        **kwargs: Forwarded unchanged to ``keras.Model.__init__``.
    """

    def __init__(self, encoder=None, decoder=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        """Return ``decoder(encoder(inputs))``."""
        z = self.encoder(inputs)
        y_pred = self.decoder(z)
        return y_pred

    def train_step(self, data):
        """Run one optimisation step on a single batch.

        Args:
            data: A pair ``(x, y)``: the batch given to the encoder and
                the target the decoder output is compared against.

        Returns:
            A dict mapping each metric name to its current result.
        """
        x, y = data

        with tf.GradientTape() as tape:
            # Request training mode explicitly: this forward pass is not
            # nested inside an outer training call, so without the flag
            # layers such as Dropout or BatchNormalization would run in
            # inference mode while the model is being fitted.
            z = self.encoder(x, training=True)
            y_pred = self.decoder(z, training=True)

            # Loss configured in compile(), plus the layers' own
            # regularization losses.
            loss = self.compiled_loss(
                y, y_pred, regularization_losses=self.losses
            )

        # ---- Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # ---- Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # ---- Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)

        # ---- Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def reload(self, filename):
        """Replace the encoder and decoder with models loaded from disk.

        Args:
            filename: Path prefix; the models are read from
                ``{filename}-enc.h5`` and ``{filename}-dec.h5``.
        """
        self.encoder = keras.models.load_model(f'{filename}-enc.h5')
        self.decoder = keras.models.load_model(f'{filename}-dec.h5')

    def save(self, filename):
        """Save the encoder and the decoder as two separate files.

        Args:
            filename: Path prefix; the models are written to
                ``{filename}-enc.h5`` and ``{filename}-dec.h5``.
        """
        # NOTE(review): this shadows keras.Model.save with a narrower
        # signature, so callers passing Keras' own save arguments (e.g.
        # ``overwrite=``) will fail - confirm this is intended.
        self.encoder.save(f'{filename}-enc.h5')
        self.decoder.save(f'{filename}-dec.h5')