{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [K3AE3] - Playing with our denoiser model\n",
    "<!-- DESC --> Episode 2 : Using the previously trained autoencoder to denoise data\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Retrieve and use our denoiser model\n",
    "\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Reload our dataset and saved best model\n",
    " - Encode/decode some test images (never used, never seen by the model)\n",
    " \n",
    "## Data Terminology :\n",
    "- `clean_train`, `clean_test` for noiseless images \n",
    "- `noisy_train`, `noisy_test` for noisy images\n",
    "- `denoised_test` for denoised images at the output of the model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff\n",
    "### 1.1 - Init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['KERAS_BACKEND'] = 'torch'\n",
    "\n",
    "import keras\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "\n",
    "from modules.MNIST import MNIST\n",
    "\n",
    "import fidle\n",
    "\n",
    "# Init Fidle environment\n",
    "run_id, run_dir, datasets_dir = fidle.init('K3AE3')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - Parameters\n",
    "These **parameters must be identical** to those used during the training in order to have the **same dataset**.\\\n",
    "`prepared_dataset` : Filename of the prepared dataset (needs 400 MB, but can be in ./data)  \n",
    "`dataset_seed` : Random seed for shuffling dataset  \n",
    "`scale` : Fraction of the dataset to use (1. for 100%)  \n",
    "`train_prop` : Fraction of the dataset used for training (the rest being used for testing)  \n",
    "`saved_models` : Directory of the models saved during the training run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prepared_dataset = './data/mnist-noisy.h5'\n",
    "saved_models     = './run/K3AE2/models'\n",
    "dataset_seed     = 123\n",
    "scale            = 1\n",
    "train_prop       = .8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.override('prepared_dataset', 'dataset_seed', 'scale', 'train_prop')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Retrieve dataset\n",
    "With our MNIST class, in one call, we can reload, rescale, shuffle and split our previously saved dataset :-)  \n",
    "**Important :** Make sure that the **digest is identical** to the one used during the training !\\\n",
    "See : [AE2 / Step 2 - Retrieve dataset](./02-AE-with-MNIST.ipynb#Step-2---Retrieve-dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_train,clean_test, noisy_train,noisy_test, _,_ = MNIST.reload_prepared_dataset(scale      = scale, \n",
    "                                                                                    train_prop = train_prop,\n",
    "                                                                                    seed       = dataset_seed,\n",
    "                                                                                    shuffle    = True,\n",
    "                                                                                    filename=prepared_dataset )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Evaluation\n",
    "**Note :** We will use the following data:\\\n",
    "`clean_train`, `clean_test` for noiseless images \\\n",
    "`noisy_train`, `noisy_test` for noisy images\\\n",
    "`denoised_test` for denoised images at the output of the model\n",
    " \n",
    "### 3.1 - Reload our best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = keras.models.load_model(f'{saved_models}/model.keras')\n",
    "\n",
    "# Reload the encoder and decoder saved during training\n",
    "encoder = keras.models.load_model(f'{saved_models}/encoder.keras')\n",
    "decoder = keras.models.load_model(f'{saved_models}/decoder.keras')\n",
    "\n",
    "# Chain them to rebuild the full autoencoder: image -> latent -> image\n",
    "inputs    = keras.Input(shape=(28, 28, 1))\n",
    "\n",
    "latents   = encoder(inputs)\n",
    "outputs   = decoder(latents)\n",
    "\n",
    "model = keras.Model(inputs,outputs, name=\"ae\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 - Let's make a prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "denoised_test = model.predict(noisy_test,verbose=0)\n",
    "\n",
    "print('Denoised images   (denoised_test) shape : ',denoised_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 - Denoised images "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick a random window of 8 consecutive test images\n",
    "i=random.randint(0,len(denoised_test)-8)\n",
    "j=i+8\n",
    "\n",
    "fidle.utils.subtitle('Noisy test images (input):')\n",
    "fidle.scrawler.images(noisy_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='05-test-noisy')\n",
    "\n",
    "fidle.utils.subtitle('Denoised images (output):')\n",
    "fidle.scrawler.images(denoised_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='06-test-predict')\n",
    "\n",
    "fidle.utils.subtitle('Real test images :')\n",
    "fidle.scrawler.images(clean_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='07-test-real')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Looking at the latent space\n",
    "### 4.1 - Getting clean data and class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_data,_, _,_, class_data,_ = MNIST.reload_prepared_dataset(scale      = 1, \n",
    "                                                                train_prop = 1,\n",
    "                                                                seed       = dataset_seed,\n",
    "                                                                shuffle    = False,\n",
    "                                                                filename   = prepared_dataset )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 - Retrieve encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder=model.get_layer('encoder')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 - Showing latent space\n",
    "Here is the digit distribution in the latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_show = 20000\n",
    "\n",
    "# ---- Select images\n",
    "\n",
    "x_show, y_show = fidle.utils.pick_dataset(clean_data, class_data, n=n_show)\n",
    "\n",
    "# ---- Get latent points\n",
    "\n",
    "z = encoder.predict(x_show)\n",
    "\n",
    "# ---- Show them\n",
    "\n",
    "fig = plt.figure(figsize=(14, 10))\n",
    "plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)\n",
    "plt.colorbar()\n",
    "fidle.scrawler.save_fig('08-Latent-space')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.2 ('fidle-env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}