{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n", "\n", "# <!-- TITLE --> [VAE3] - Analysis of the VAE's latent space of MNIST dataset\n", "<!-- DESC --> Visualization and analysis of the VAE's latent space of the dataset MNIST\n", "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", "## Objectives :\n", " - First data generation from **latent space** \n", " - Understanding of underlying principles\n", " - Model management\n", "\n", "Here, we don't consume data anymore, but we generate them ! ;-)\n", "\n", "## What we're going to do :\n", "\n", " - Load a saved model\n", " - Reconstruct some images\n", " - Latent space visualization\n", " - Matrix of generated images\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Init python stuff" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 - Init python" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "from modules.models import VAE\n", "from modules.datagen import MNIST\n", "\n", "import scipy.stats\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "import sys\n", "sys.path.append('..')\n", "import fidle.pwk as pwk\n", "\n", "run_dir = './run/VAE2.001'\n", "datasets_dir = pwk.init('VAE3', run_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 - Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "scale = 1\n", "seed = 123" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override parameters (batch mode) - Just forget this cell" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pwk.override('scale', 'seed')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Get 
data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Reload best model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae=VAE()\n", "vae.reload(f'{run_dir}/models/best_model')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Image reconstruction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ---- Select a few images\n", "\n", "x_show = pwk.pick_dataset(x_data, n=10)\n", "\n", "# ---- Get latent points and reconstructed images\n", "\n", "z_mean, z_var, z = vae.encoder.predict(x_show)\n", "x_reconst = vae.decoder.predict(z)\n", "\n", "# ---- Show them (originals labeled with their latent coordinates)\n", "\n", "labels=[ str(np.round(z[i],1)) for i in range(10) ]\n", "pwk.plot_images(x_show, labels, indices='all', columns=10, x_size=2,y_size=2, save_as='01-original')\n", "pwk.plot_images(x_reconst, None , indices='all', columns=10, x_size=2,y_size=2, save_as='02-reconstruct')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Visualizing the latent space" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_show = 20000\n", "\n", "# ---- Select images\n", "\n", "x_show, y_show = pwk.pick_dataset(x_data,y_data, n=n_show)\n", "\n", "# ---- Get latent points\n", "\n", "z_mean, z_var, z = vae.encoder.predict(x_show)\n", "\n", "# ---- Show them\n", "\n", "fig = plt.figure(figsize=(14, 10))\n", "plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)\n", "plt.colorbar()\n", "pwk.save_fig('03-Latent-space')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 6 - Generate from latent space" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"grid_size = 14\n", "grid_scale = 1.\n", "\n", "# ---- Draw a ppf grid\n", "\n", "grid=[]\n", "for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):\n", " for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):\n", " grid.append( (x,y) )\n", "grid=np.array(grid)\n", "\n", "# ---- Draw latentspoints and grid\n", "\n", "fig = plt.figure(figsize=(12, 10))\n", "plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)\n", "plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)\n", "pwk.save_fig('04-Latent-grid')\n", "plt.show()\n", "\n", "# ---- Plot grid corresponding images\n", "\n", "x_reconst = vae.decoder.predict([grid])\n", "pwk.plot_images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='05-Latent-morphing')\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pwk.end()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>" ] } ], "metadata": { "interpreter": { "hash": "8e38643e33497db9a306e3f311fa98cb1e65371278ca73ee4ea0c76aa5a4f387" }, "kernelspec": { "display_name": "Python 3.9.7 64-bit ('fidle-cpu': conda)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }