{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [LVAE3] - Analysis of the VAE's latent space of MNIST dataset\n",
    "<!-- DESC --> Visualization and analysis of the VAE's latent space of the dataset MNIST, using PyTorch Lightning\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - First data generation from **latent space** \n",
    " - Understanding of underlying principles\n",
    " - Model management\n",
    "\n",
    "Here, we don't consume data anymore, but we generate them ! ;-)\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Load a saved model\n",
    " - Reconstruct some images\n",
    " - Latent space visualization\n",
    " - Matrix of generated images\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 - Init python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy  as np\n",
    "import torch.nn as nn\n",
    "\n",
    "from modules.callbacks import ImagesCallback, BestModelCallback\n",
    "from modules.datagen   import MNIST\n",
    "from modules.models    import Encoder, Decoder, VAE \n",
    "\n",
    "\n",
    "import scipy.stats\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from barviz import Simplex\n",
    "from barviz import Collection\n",
    "\n",
    "\n",
    "import fidle\n",
    "\n",
    "# Init Fidle environment\n",
    "run_id, run_dir, datasets_dir = fidle.init('LVAE3')\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scale      = 1\n",
    "seed       = 123\n",
    "models_dir = './run/models_dir/best-model-epoch=4-loss=0.00.ckpt'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.override('scale', 'seed', 'models_dir')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Reload best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#---- Load the model from a checkpoint\n",
    "latent_dim=6\n",
    "\n",
    "vae = VAE.load_from_checkpoint(models_dir,\n",
    "                               encoder=Encoder(latent_dim=latent_dim),\n",
    "                               decoder=Decoder(latent_dim=latent_dim)\n",
    "                              )\n",
    "# put model in evaluation mode\n",
    "vae.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Image reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Select few images\n",
    "\n",
    "x_show = fidle.utils.pick_dataset(x_data, n=10)\n",
    "\n",
    "# ---- Get latent points and reconstructed images\n",
    "\n",
    "z_mean, z_var, z  = vae.encoder(x_show.to(device))\n",
    "x_reconst         = vae.decoder(z)\n",
    "\n",
    "latent_dim        = z.shape[1]\n",
    "\n",
    "# ---- Show it\n",
    "z         = z.cpu().detach()         # Move the tensor to CPU and detach it\n",
    "x_reconst = x_reconst.cpu().detach()\n",
    "\n",
    "labels=[ str(np.round(z[i],1)) for i in range(10) ]\n",
    "fidle.utils.subtitle('Originals :')\n",
    "fidle.scrawler.images(x_show,    None, indices='all', columns=10, x_size=2,y_size=2, save_as='01-original')\n",
    "fidle.utils.subtitle('Reconstructed :')\n",
    "fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='02-reconstruct')\n"
   ]
  },
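  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Bonus - a minimal sketch, not part of the original workflow : we can also generate brand new digits by decoding points sampled from the N(0,I) prior, assuming the decoder accepts any (n, latent_dim) float tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Minimal sketch : decode a few random points drawn from the N(0,I) prior\n",
    "\n",
    "z_random = torch.randn(10, latent_dim).to(device)\n",
    "x_random = vae.decoder(z_random).cpu().detach()\n",
    "\n",
    "fidle.scrawler.images(x_random, None, indices='all', columns=10, x_size=2, y_size=2, save_as='02b-generated')"
   ]
  },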
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Visualizing the latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_show = 5000\n",
    "\n",
    "# ---- Select images\n",
    "\n",
    "x_show, y_show   = fidle.utils.pick_dataset(x_data,y_data, n=n_show)\n",
    "\n",
    "# ---- Get latent points\n",
    "\n",
    "z_mean, z_var, z = vae.encoder(x_show.to(device))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 - Classic 2d visualisaton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z   = z.cpu().detach()\n",
    "fig = plt.figure(figsize=(14, 10))\n",
    "plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)\n",
    "plt.colorbar()\n",
    "fidle.scrawler.save_fig('03-Latent-space')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 - Simplex visualisaton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if latent_dim<4:\n",
    "\n",
    "    print('Sorry, This part can only work if the latent space is greater than 3')\n",
    "\n",
    "else:\n",
    "\n",
    "    # ---- Softmax rescale\n",
    "    #\n",
    "    zs = torch.exp(z)/torch.sum(torch.exp(z),axis=1,keepdims=True)\n",
    "    zs=zs.cpu().detach()\n",
    "    # zc  = zs * 1/np.max(zs)\n",
    "\n",
    "    # ---- Create collection\n",
    "    #\n",
    "    c = Collection(zs, colors=y_show, labels=y_show)\n",
    "    c.attrs.markers_colormap     = {'colorscale':'Rainbow','cmin':0,'cmax':latent_dim}\n",
    "    c.attrs.markers_size         = 5\n",
    "    c.attrs.markers_border_width = 0\n",
    "    c.attrs.markers_opacity      = 0.8\n",
    "\n",
    "    s = Simplex.build(latent_dim)\n",
    "    s.attrs.width  = 1000\n",
    "    s.attrs.height = 1000\n",
    "    s.plot(c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - Generate from latent space (latent_dim==2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if latent_dim>2:\n",
    "\n",
    "    print('Sorry, This part can only work if the latent space is of dimension 2')\n",
    "\n",
    "else:\n",
    "\n",
    "    grid_size   = 14\n",
    "    grid_scale  = 1.\n",
    "\n",
    "    # ---- Draw a ppf grid\n",
    "\n",
    "    grid=[]\n",
    "    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):\n",
    "        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):\n",
    "            grid.append( (x,y) )\n",
    "    grid=np.array(grid)\n",
    "\n",
    "    # ---- Draw latentspoints and grid\n",
    "\n",
    "    fig = plt.figure(figsize=(12, 10))\n",
    "    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)\n",
    "    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)\n",
    "    fidle.scrawler.save_fig('04-Latent-grid')\n",
    "    plt.show()\n",
    "\n",
    "    # ---- Plot grid corresponding images\n",
    "    grid      = torch.from_numpy(grid).to(device)\n",
    "    x_reconst = vae.decoder([grid])\n",
    "    x_reconst = x_reconst.cpu().detach()\n",
    "    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='05-Latent-morphing')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}