{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [K3AE4] - Denoiser and classifier model\n",
    "<!-- DESC --> Episode 4 : Construction of a denoiser and classifier model\n",
    "\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Building a multiple output model, able to **denoise** and **classify**\n",
    " - Understanding a more **advanced programming model**\n",
    "\n",
    "The calculation needs being important, it is preferable to use a very simple dataset such as MNIST.  \n",
    "The use of a GPU is often indispensable.\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Defining a multiple output model using Keras procedural programing model\n",
    " - Build the model\n",
    " - Train it\n",
    " - Follow the learning process\n",
    " \n",
    "## Data Terminology :\n",
    "- `clean_train`, `clean_test` for noiseless images \n",
    "- `noisy_train`, `noisy_test` for noisy images\n",
    "- `class_train`, `class_test` for the classes to which the images belong \n",
    "- `denoised_test` for denoised images at the output of the model\n",
    "- `classcat_test` for class prediction in model output (is a softmax)\n",
    "- `classid_test` class prediction (ie: argmax of classcat_test)\n"
   ]
  },
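  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relation between `classcat_test` and `classid_test` can be illustrated with a small NumPy sketch. The logits below are arbitrary stand-ins, not outputs of this notebook's model, and `softmax` is written out by hand only for illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Arbitrary logits for 3 images and 10 classes (digits 0-9), made up for the example\n",
    "logits = np.zeros((3, 10))\n",
    "logits[0, 1] = 2.0\n",
    "logits[1, 7] = 3.5\n",
    "logits[2, 0] = 1.2\n",
    "\n",
    "def softmax(z):\n",
    "    # Subtract the row maximum for numerical stability, then normalise each row to sum to 1\n",
    "    e = np.exp(z - z.max(axis=-1, keepdims=True))\n",
    "    return e / e.sum(axis=-1, keepdims=True)\n",
    "\n",
    "classcat = softmax(logits)              # like classcat_test: one probability vector per image\n",
    "classid  = np.argmax(classcat, axis=-1) # like classid_test: index of the most probable class\n",
    "\n",
    "print(classcat.shape)  # (3, 10)\n",
    "print(classid)         # [1 7 0]\n",
    "```\n",
    "\n",
    "Step 7 obtains `classid_test` the same way, with `np.argmax(classcat, axis=-1)`."
   ]
  },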
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff\n",
    "### 1.1 - Init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['KERAS_BACKEND'] = 'torch'\n",
    "\n",
    "import keras\n",
    "\n",
    "import numpy as np\n",
    "from skimage import io\n",
    "import random\n",
    "\n",
    "from modules.AE4_builder    import AE4_builder\n",
    "from modules.MNIST          import MNIST\n",
    "from modules.ImagesCallback import ImagesCallback\n",
    "\n",
    "import fidle\n",
    "\n",
    "# Init Fidle environment\n",
    "run_id, run_dir, datasets_dir = fidle.init('K3AE4')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - Parameters\n",
    "`prepared_dataset` : Filename of the prepared dataset (Need 400 Mo, but can be in ./data)  \n",
    "`dataset_seed` : Random seed for shuffling dataset. 'None' mean using /dev/urandom  \n",
    "`scale` : % of the dataset to use (1. for 100%)  \n",
    "`latent_dim` : Dimension of the latent space  \n",
    "`train_prop` : Percentage for train (the rest being for the test)\n",
    "`batch_size` : Batch size  \n",
    "`epochs` : Nb of epochs for training\\\n",
    "`fit_verbosity` is the verbosity during training : 0 = silent, 1 = progress bar, 2 = one line per epoch\n",
    "\n",
    "scale=0.1, epochs=20  => 2' on a laptop\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prepared_dataset = './data/mnist-noisy.h5'\n",
    "dataset_seed     = None\n",
    "\n",
    "scale            = .1\n",
    "\n",
    "train_prop       = .8\n",
    "batch_size       = 128\n",
    "epochs           = 10\n",
    "fit_verbosity    = 1"
   ]
  },
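  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameters above fix how many images are used and how many batches each epoch contains. A back-of-the-envelope sketch, **assuming** the prepared dataset holds 70,000 images (the 60,000 training and 10,000 test images of MNIST merged); adjust `n_total` if your prepared file differs, and note that the exact rounding used by the `MNIST` module may differ:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "n_total    = 70_000  # assumption: size of the prepared dataset\n",
    "scale      = 0.1     # same values as the parameter cell\n",
    "train_prop = 0.8\n",
    "batch_size = 128\n",
    "\n",
    "n_used  = int(n_total * scale)            # images kept after rescaling\n",
    "n_train = int(n_used * train_prop)        # images used for training\n",
    "n_test  = n_used - n_train                # remaining images, used for testing\n",
    "steps   = math.ceil(n_train / batch_size) # batches per epoch (the last one may be partial)\n",
    "\n",
    "print(n_used, n_train, n_test, steps)     # 7000 5600 1400 44\n",
    "```\n",
    "\n",
    "With `epochs = 10`, that amounts to 440 weight updates in total."
   ]
  },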
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.override('prepared_dataset', 'dataset_seed', 'scale')\n",
    "fidle.override('train_prop', 'batch_size', 'epochs', 'fit_verbosity')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Retrieve dataset\n",
    "With our MNIST class, in one call, we can reload, rescale, shuffle and split our previously saved dataset :-)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_train,clean_test, noisy_train,noisy_test, class_train,class_test = MNIST.reload_prepared_dataset(\n",
    "                                                                                    scale      = scale, \n",
    "                                                                                    train_prop = train_prop,\n",
    "                                                                                    seed       = dataset_seed,\n",
    "                                                                                    shuffle    = True,\n",
    "                                                                                    filename   = prepared_dataset )"
   ]
  },
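  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MNIST.reload_prepared_dataset` is a Fidle helper whose internals are not shown here. The sketch below is a **hypothetical** NumPy equivalent of the rescale/shuffle/split step, not the module's actual code. The point it illustrates is that the same permutation must be applied to the clean images, the noisy images and the classes so that they stay aligned:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def rescale_shuffle_split(arrays, scale=1.0, train_prop=0.8, seed=None, shuffle=True):\n",
    "    # Hypothetical helper: one permutation shared by all arrays keeps them aligned\n",
    "    n   = len(arrays[0])\n",
    "    rng = np.random.default_rng(seed)             # seed=None: fresh entropy from the OS\n",
    "    idx = rng.permutation(n) if shuffle else np.arange(n)\n",
    "    idx = idx[:int(n * scale)]                    # keep only a fraction of the data\n",
    "    n_train = int(len(idx) * train_prop)\n",
    "    train_idx, test_idx = idx[:n_train], idx[n_train:]\n",
    "    out = []\n",
    "    for a in arrays:\n",
    "        out += [a[train_idx], a[test_idx]]        # (train, test) pair for each array\n",
    "    return out\n",
    "\n",
    "# Toy arrays: noisy = clean + 1000 and classes = clean % 10, so alignment is easy to check\n",
    "clean   = np.arange(100).reshape(100, 1)\n",
    "noisy   = clean + 1000\n",
    "classes = np.arange(100) % 10\n",
    "\n",
    "clean_tr, clean_te, noisy_tr, noisy_te, cls_tr, cls_te = rescale_shuffle_split(\n",
    "    [clean, noisy, classes], scale=0.5, train_prop=0.8, seed=0)\n",
    "\n",
    "print(len(clean_tr), len(clean_te))               # 40 10\n",
    "print(np.array_equal(noisy_tr - 1000, clean_tr))  # True\n",
    "```\n",
    "\n",
    "The returned order (clean, noisy, classes; train then test) mirrors the six arrays returned in the cell above."
   ]
  },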
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Build models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder = AE4_builder( ae={ 'latent_dim':10 }, cnn = { 'lc1':8, 'lc2':16, 'ld':100 } )\n",
    "\n",
    "model = builder.create_model()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer='rmsprop', \n",
    "              loss={'ae':'binary_crossentropy', 'classifier':'sparse_categorical_crossentropy'},\n",
    "              loss_weights={'ae':1., 'classifier':1.},\n",
    "              metrics={'classifier':'accuracy'} )"
   ]
  },
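  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `loss_weights={'ae':1., 'classifier':1.}`, the loss minimised during training is the weighted sum of the two per-output losses. A simplified NumPy sketch of that combination; it mirrors the idea, not Keras's exact implementation (regularization losses and sample weights are ignored, and the clipping constant `EPS` is an assumption):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "EPS = 1e-7  # assumption: probabilities are clipped away from 0 and 1 before taking logs\n",
    "\n",
    "def binary_crossentropy(y_true, y_pred):\n",
    "    # Pixel-wise BCE averaged over all pixels and images (loss of the ae output)\n",
    "    p = np.clip(y_pred, EPS, 1 - EPS)\n",
    "    return np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p)))\n",
    "\n",
    "def sparse_categorical_crossentropy(labels, probs):\n",
    "    # -log(probability given to the true class), averaged over images (classifier output)\n",
    "    p = np.clip(probs[np.arange(len(labels)), labels], EPS, 1.0)\n",
    "    return np.mean(-np.log(p))\n",
    "\n",
    "def total_loss(clean, denoised, labels, classcat, w_ae=1.0, w_cls=1.0):\n",
    "    # Weighted sum of the per-output losses, as set by loss_weights\n",
    "    return (w_ae * binary_crossentropy(clean, denoised)\n",
    "            + w_cls * sparse_categorical_crossentropy(labels, classcat))\n",
    "\n",
    "# Toy check: a denoiser predicting 0.5 on black images and a uniform classifier\n",
    "clean    = np.zeros((2, 28, 28, 1))\n",
    "denoised = np.full_like(clean, 0.5)\n",
    "labels   = np.array([3, 7])\n",
    "classcat = np.full((2, 10), 0.1)\n",
    "print(round(total_loss(clean, denoised, labels, classcat), 4))  # 2.9957 (= ln 2 + ln 10)\n",
    "```\n",
    "\n",
    "Raising the `classifier` weight, for example, would make the optimizer favour classification over reconstruction."
   ]
  },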
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keras.utils.plot_model(model, \"multi_input_and_output_model.png\", show_shapes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Train\n",
    "20' on a CPU  \n",
    "1'12 on a GPU (V100, IDRIS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Callback : Images\n",
    "#\n",
    "fidle.utils.mkdir( run_dir + '/images')\n",
    "filename = run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'\n",
    "\n",
    "encoder = model.get_layer('ae').get_layer('encoder')\n",
    "decoder = model.get_layer('ae').get_layer('decoder')\n",
    "\n",
    "callback_images = ImagesCallback(filename, x=clean_test[:5], encoder=encoder,decoder=decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chrono = fidle.Chrono()\n",
    "chrono.start()\n",
    "\n",
    "history = model.fit(noisy_train, [clean_train, class_train],\n",
    "                 batch_size      = batch_size,\n",
    "                 epochs          = epochs,\n",
    "                 verbose         = fit_verbosity,\n",
    "                 validation_data = (noisy_test, [clean_test, class_test]),\n",
    "                 callbacks       = [ callback_images ]  )\n",
    "\n",
    "chrono.show()"
   ]
  },
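  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`model.fit` receives one input array and a list of two target arrays, so row *k* of `noisy_train` must correspond to row *k* of both `clean_train` and `class_train`. A simplified NumPy sketch of shuffled mini-batching that preserves this alignment (it illustrates the idea only and is not Keras's actual data pipeline):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def iterate_minibatches(x, targets, batch_size, rng):\n",
    "    # One shuffled pass: the same indices select the input and every target array\n",
    "    idx = rng.permutation(len(x))\n",
    "    for start in range(0, len(x), batch_size):\n",
    "        b = idx[start:start + batch_size]\n",
    "        yield x[b], [t[b] for t in targets]\n",
    "\n",
    "rng    = np.random.default_rng(0)\n",
    "clean  = np.arange(10)   # toy stand-in for clean_train\n",
    "noisy  = clean + 100     # toy stand-in for noisy_train\n",
    "labels = clean % 3       # toy stand-in for class_train\n",
    "\n",
    "sizes = []\n",
    "for xb, (yb_clean, yb_cls) in iterate_minibatches(noisy, [clean, labels], 4, rng):\n",
    "    assert np.array_equal(xb - 100, yb_clean)    # input and first target stay aligned\n",
    "    assert np.array_equal(yb_clean % 3, yb_cls)  # second target stays aligned too\n",
    "    sizes.append(len(xb))\n",
    "print(sizes)  # [4, 4, 2]\n",
    "```"
   ]
  },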
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save model weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(f'{run_dir}/models', exist_ok=True)\n",
    "\n",
    "model.save_weights(f'{run_dir}/models/model.weights.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - History"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.scrawler.history(history,  plot={'Loss':['loss', 'val_loss'],\n",
    "                                 'Accuracy':['classifier_accuracy','val_classifier_accuracy']}, save_as='01-history')"
   ]
  },
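  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`history.history` is a dictionary mapping each logged metric name (for instance `val_classifier_accuracy`, plotted above) to its list of per-epoch values. A small helper (hypothetical, not part of `fidle`) to locate the best epoch:\n",
    "\n",
    "```python\n",
    "def best_epoch(history_dict, key='val_classifier_accuracy', mode='max'):\n",
    "    # history_dict: metric name -> list of per-epoch values (as in history.history)\n",
    "    values = history_dict[key]\n",
    "    pick   = max if mode == 'max' else min\n",
    "    best   = pick(range(len(values)), key=values.__getitem__)\n",
    "    return best + 1, values[best]   # 1-based epoch number and the metric value\n",
    "\n",
    "# Toy values, for illustration only\n",
    "toy = {'val_classifier_accuracy': [0.50, 0.90, 0.80],\n",
    "       'val_loss':                [1.00, 0.40, 0.60]}\n",
    "print(best_epoch(toy))                     # (2, 0.9)\n",
    "print(best_epoch(toy, 'val_loss', 'min'))  # (2, 0.4)\n",
    "```\n",
    "\n",
    "On the actual run, this would be called as `best_epoch(history.history)`."
   ]
  },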
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - Denoising progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs=[]\n",
    "for epoch in range(0,epochs,4):\n",
    "    for i in range(5):\n",
    "        filename = run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'.format(epoch=epoch, i=i)\n",
    "        img      = io.imread(filename)\n",
    "        imgs.append(img)      \n",
    "\n",
    "fidle.utils.subtitle('Real images (clean_test) :')\n",
    "fidle.scrawler.images(clean_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as='02-original-real')\n",
    "\n",
    "fidle.utils.subtitle('Noisy images (noisy_test) :')\n",
    "fidle.scrawler.images(noisy_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as='03-original-noisy')\n",
    "\n",
    "fidle.utils.subtitle('Evolution during the training period (denoised_test) :')\n",
    "fidle.scrawler.images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, y_padding=0.1, save_as='04-learning')\n",
    "\n",
    "fidle.utils.subtitle('Noisy images (noisy_test) :')\n",
    "fidle.scrawler.images(noisy_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as=None)\n",
    "\n",
    "fidle.utils.subtitle('Real images (clean_test) :')\n",
    "fidle.scrawler.images(clean_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as=None)\n"
   ]
  },
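  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above reads back one snapshot every 4 epochs. A small helper (hypothetical, not part of `fidle`) that lists the expected files makes it easy to see how many rows the evolution grid will have, and to check that the files exist before calling `io.imread`:\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def progress_filenames(run_dir, epochs, every=4, n_images=5):\n",
    "    # Same naming pattern as the ImagesCallback template defined in Step 4\n",
    "    names = []\n",
    "    for epoch in range(0, epochs, every):\n",
    "        for i in range(n_images):\n",
    "            names.append(run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'.format(epoch=epoch, i=i))\n",
    "    return names\n",
    "\n",
    "names = progress_filenames('./run/K3AE4', epochs=10)  # hypothetical run_dir\n",
    "print(len(names))  # 15  (epochs 0, 4 and 8, 5 images each)\n",
    "print(names[0])    # ./run/K3AE4/images/image-000-00.jpg\n",
    "missing = [f for f in names if not os.path.exists(f)]  # files io.imread would fail on\n",
    "```\n",
    "\n",
    "With `epochs = 10`, the evolution grid therefore has 3 rows of 5 images."
   ]
  },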
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7 - Evaluation\n",
    "**Note :** We will use the following data:\\\n",
    "`clean_train`, `clean_test` for noiseless images \\\n",
    "`noisy_train`, `noisy_test` for noisy images\\\n",
    "`class_train`, `class_test` for the classes to which the images belong \\\n",
    "`denoised_test` for denoised images at the output of the model\\\n",
    "`classcat_test` for class prediction in model output (is a softmax)\\\n",
    "`classid_test` class prediction (ie: argmax of classcat_test)\n",
    " \n",
    "### 7.1 - Reload our model (weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder = AE4_builder( ae={ 'latent_dim':10 }, cnn = { 'lc1':8, 'lc2':16, 'ld':100 } )\n",
    "\n",
    "model = builder.create_model()\n",
    "\n",
    "model.load_weights(f'{run_dir}/models/model.weights.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 - Let's make a prediction\n",
    "Note that our model will returns 2 outputs : **denoised images** from output 1 and **class prediction** from output 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = model.predict(noisy_test, verbose=0)\n",
    "\n",
    "denoised = outputs['ae']\n",
    "classcat = outputs['classifier']\n",
    "\n",
    "print('Denoised images   (denoised_test) shape : ', denoised.shape)\n",
    "print('Predicted classes (classcat_test) shape : ', classcat.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.3 - Denoised images "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i=random.randint(0,len(denoised)-8)\n",
    "j=i+8\n",
    "\n",
    "fidle.utils.subtitle('Noisy test images (input):')\n",
    "fidle.scrawler.images(noisy_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='05-test-noisy')\n",
    "\n",
    "fidle.utils.subtitle('Denoised images (output):')\n",
    "fidle.scrawler.images(denoised[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='06-test-predict')\n",
    "\n",
    "fidle.utils.subtitle('Real test images :')\n",
    "fidle.scrawler.images(clean_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='07-test-real')"
   ]
  },
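  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comparison above is visual. To quantify it, one option is the peak signal-to-noise ratio (PSNR), where higher is better. A minimal sketch, **assuming** pixel values lie in [0, 1], as the binary cross-entropy loss on the `ae` output suggests; scikit-image, already imported here, also provides `skimage.metrics.peak_signal_noise_ratio`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def psnr(reference, image, data_range=1.0):\n",
    "    # Peak signal-to-noise ratio in dB; assumes pixel values in [0, data_range]\n",
    "    ref = np.asarray(reference, dtype=np.float64)\n",
    "    img = np.asarray(image, dtype=np.float64)\n",
    "    mse = np.mean((ref - img) ** 2)\n",
    "    if mse == 0:\n",
    "        return float('inf')\n",
    "    return 10.0 * np.log10(data_range ** 2 / mse)\n",
    "\n",
    "# Toy check: a uniform error of 0.1 gives an MSE of 0.01, i.e. 20 dB\n",
    "print(round(psnr(np.zeros((28, 28)), np.full((28, 28), 0.1)), 2))  # 20.0\n",
    "```\n",
    "\n",
    "For instance, compare `psnr(clean_test, noisy_test)` with `psnr(clean_test, denoised)`: a useful denoiser should give the larger value for the second."
   ]
  },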
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.4 - Class prediction\n",
    "Note: The evaluation requires the noisy images as input (noisy_test) and the 2 expected outputs:\n",
    " - the images without noise (clean_test)\n",
    " - the classes (class_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We need to (re)compile our resurrected model (to specify loss and metrics)\n",
    "#\n",
    "model.compile(optimizer='rmsprop', \n",
    "              loss={'ae':'binary_crossentropy', 'classifier':'sparse_categorical_crossentropy'},\n",
    "              loss_weights={'ae':1., 'classifier':1.},\n",
    "              metrics={'classifier':'accuracy'} )\n",
    "\n",
    "\n",
    "# Get an evaluation\n",
    "#\n",
    "score = model.evaluate(noisy_test, [clean_test, class_test], verbose=0)\n",
    "\n",
    "# And show results\n",
    "#\n",
    "fidle.utils.subtitle(\"Accuracy :\")\n",
    "print(f'Classification accuracy : {score[1]:4.4f}')\n",
    "\n",
    "fidle.utils.subtitle(\"Few examples :\")\n",
    "classid_test  = np.argmax(classcat, axis=-1)\n",
    "fidle.scrawler.images(noisy_test, class_test, range(0,200), columns=12, x_size=1, y_size=1, y_pred=classid_test, save_as='04-predictions')"
   ]
  },
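  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Overall accuracy hides which digits are confused with each other. A minimal NumPy sketch of a confusion matrix and per-class accuracy (the function names are illustrative; `sklearn.metrics.confusion_matrix` computes the same matrix if scikit-learn is installed):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def confusion_matrix(y_true, y_pred, n_classes=10):\n",
    "    # Rows: true class, columns: predicted class\n",
    "    cm = np.zeros((n_classes, n_classes), dtype=np.int64)\n",
    "    np.add.at(cm, (np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()), 1)\n",
    "    return cm\n",
    "\n",
    "def per_class_accuracy(cm):\n",
    "    # Share of each true class that was predicted correctly (nan if the class is absent)\n",
    "    totals = cm.sum(axis=1)\n",
    "    with np.errstate(invalid='ignore', divide='ignore'):\n",
    "        return np.diag(cm) / totals\n",
    "\n",
    "cm = confusion_matrix([0, 0, 1, 2], [0, 1, 1, 2], n_classes=3)\n",
    "print(cm.tolist())                      # [[1, 1, 0], [0, 1, 0], [0, 0, 1]]\n",
    "print(per_class_accuracy(cm).tolist())  # [0.5, 1.0, 1.0]\n",
    "print(np.trace(cm) / cm.sum())          # 0.75 (overall accuracy)\n",
    "```\n",
    "\n",
    "On this notebook's data, that would be `cm = confusion_matrix(class_test, classid_test)`; large off-diagonal entries point to the digit pairs the classifier confuses."
   ]
  },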
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.2 ('fidle-env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}