{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [VAE5] - Checking the clustered CelebA dataset\n",
    "<!-- DESC --> Verification of prepared data from CelebA dataset\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - Making sure our clustered dataset is correct\n",
    " - Do a little bit of python while waiting to build and train our VAE model.\n",
    "\n",
    "The [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains about 200,000 images (202599,218,178,3).  \n",
    "\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Reload our dataset\n",
    " - Check and verify our clustered dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Import and init\n",
    "### 1.2 - Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import os,time,sys,json,glob,importlib\n",
    "import math, random\n",
    "\n",
    "import modules.data_generator\n",
    "from modules.data_generator import DataGenerator\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as ooo\n",
    "\n",
    "ooo.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - Directories and files :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "place, dataset_dir = ooo.good_place( { 'GRICAD' : f'{os.getenv(\"SCRATCH_DIR\",\"\")}/PROJECTS/pr-fidle/datasets/celeba',\n",
    "                                       'IDRIS'  : f'{os.getenv(\"WORK\",\"\")}/datasets/celeba',\n",
    "                                       'HOME'   : f'{os.getenv(\"HOME\",\"\")}/datasets/celeba'} )\n",
    "\n",
    "train_dir    = f'{dataset_dir}/clusters.train'\n",
    "test_dir     = f'{dataset_dir}/clusters.test'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Data verification\n",
    "What we're going to do:\n",
    " - Recover all clusters by normalizing images\n",
    " - Make some statistics to be sure we have all the data\n",
    " - picking one image per cluster to check that everything is good."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Return a legend from a description \n",
    "def get_legend(x_desc,i):\n",
    "    cols  = x_desc.columns\n",
    "    desc  = x_desc.iloc[i]\n",
    "    legend =[]\n",
    "    for i,v in enumerate(desc):\n",
    "        if v==1 : legend.append(cols[i])\n",
    "    return str('\\n'.join(legend))\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "# ---- get cluster list\n",
    "\n",
    "clusters_name = [ os.path.splitext(f)[0] for f in glob.glob( f'{train_dir}/*.npy') ]\n",
    "\n",
    "# ---- Counters set to 0\n",
    "\n",
    "imax  = len(clusters_name)\n",
    "i,n1,n2,s = 0,0,0,0\n",
    "imgs,desc = [],[]\n",
    "\n",
    "# ---- Reload all clusters\n",
    "\n",
    "ooo.update_progress('Load clusters :',i,imax, redraw=True)\n",
    "for cluster_name in clusters_name:  \n",
    "    \n",
    "    # ---- Reload images and normalize\n",
    "\n",
    "    x_data = np.load(cluster_name+'.npy')\n",
    "    \n",
    "    # ---- Reload descriptions\n",
    "    \n",
    "    x_desc = pd.read_csv(cluster_name+'.csv', header=0)\n",
    "    \n",
    "    # ---- Counters\n",
    "    \n",
    "    n1 += len(x_data)\n",
    "    n2 += len(x_desc.index)\n",
    "    s  += x_data.nbytes\n",
    "    i  += 1\n",
    "    \n",
    "    # ---- Get somes images/legends\n",
    "    \n",
    "    j=random.randint(0,len(x_data)-1)\n",
    "    imgs.append( x_data[j].copy() )\n",
    "    desc.append( get_legend(x_desc,j) )\n",
    "    x_data=None\n",
    "    \n",
    "    # ---- To appear professional\n",
    "    \n",
    "    ooo.update_progress('Load clusters :',i,imax, redraw=True)\n",
    "\n",
    "d=time.time()-start_time\n",
    "\n",
    "print(f'Loading time      : {d:.2f} s or {ooo.hdelay(d)}')\n",
    "print(f'Number of cluster : {i}')\n",
    "print(f'Number of images  : {n1}')\n",
    "print(f'Number of desc.   : {n2}')\n",
    "print(f'Total size of img : {ooo.hsize(s)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ooo.plot_images(imgs,desc,x_size=2,y_size=2,fontsize=8,columns=7,y_padding=2.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class='nota'>\n",
    "    <b>Note :</b> With this approach, the use of data is much much more effective !\n",
    "    <ul>\n",
    "        <li>Data loading speed : <b>x 10</b> (81 s vs 16 min.)</li>\n",
    "    </ul>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - How we will read our data during the train session\n",
    "We are going to use a \"dataset reader\", which is a [tensorflow.keras.utils.Sequence](https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence)  \n",
    "The batches will be requested to our DataGenerator, which will read the clusters as they come in.\n",
    "\n",
    "### 3.1 - An example to understand"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- A very small dataset\n",
    "\n",
    "clusters_dir = f'{dataset_dir}/clusters-xs.train'\n",
    "\n",
    "# ---- Our DataGenerator\n",
    "#      with small batch size, debug mode and 50% of the dataset\n",
    "\n",
    "data_gen = DataGenerator(clusters_dir, 32, debug=True, k_size=1)\n",
    "\n",
    "# ---- We ask him to retrieve all batchs\n",
    "\n",
    "batch_sizes=[]\n",
    "for i in range( len(data_gen)):\n",
    "    x,y = data_gen[i]\n",
    "    batch_sizes.append(len(x))\n",
    "\n",
    "print(f'\\n\\ntotal number of items : {sum(batch_sizes)}')\n",
    "print(f'batch sizes      : {batch_sizes}')\n",
    "print(f'Last batch shape : {x.shape}')\n"
   ]
  },
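  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 - Under the hood : a minimal Sequence (sketch)\n",
    "The real reader is implemented in `modules/data_generator.py`. The sketch below is **not** that implementation : it only illustrates the [tensorflow.keras.utils.Sequence](https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence) contract our DataGenerator relies on : `__len__()` returns the number of batches per epoch and `__getitem__(i)` returns batch number i. The class name, file handling and parameters below are assumptions made purely for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- A minimal, purely illustrative Sequence (NOT the real DataGenerator,\n",
    "#      see modules/data_generator.py for the actual implementation)\n",
    "\n",
    "import math\n",
    "import numpy as np\n",
    "from tensorflow.keras.utils import Sequence\n",
    "\n",
    "class TinyClusterSequence(Sequence):\n",
    "    \"\"\"Serve (x, x) batches from a list of .npy cluster files (illustrative sketch).\"\"\"\n",
    "\n",
    "    def __init__(self, cluster_files, batch_size=32):\n",
    "        super().__init__()\n",
    "        self.cluster_files = cluster_files\n",
    "        self.batch_size    = batch_size\n",
    "        # Cluster sizes are read via memory mapping, without loading the images\n",
    "        self.sizes  = [ np.load(f, mmap_mode='r').shape[0] for f in cluster_files ]\n",
    "        self.n_imgs = sum(self.sizes)\n",
    "\n",
    "    def __len__(self):\n",
    "        # Number of batches per epoch\n",
    "        return math.ceil(self.n_imgs / self.batch_size)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Find the cluster containing the first image of batch idx, then slice it\n",
    "        # (a real reader would also handle batches spanning two clusters).\n",
    "        start = idx * self.batch_size\n",
    "        for f, size in zip(self.cluster_files, self.sizes):\n",
    "            if start < size:\n",
    "                x = np.asarray(np.load(f, mmap_mode='r')[start:start+self.batch_size], dtype='float32')\n",
    "                return x, x   # for a VAE, the input is also the target\n",
    "            start -= size\n",
    "        raise IndexError(idx)\n",
    "\n",
    "# Example of use (hypothetical file list) :\n",
    "# seq  = TinyClusterSequence(glob.glob(f'{train_dir}/*.npy'), batch_size=32)\n",
    "# x, y = seq[0]"
   ]
  },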
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}