{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [IMDB1] - Sentiment alalysis with text embedding\n",
    "<!-- DESC --> A very classical example of word embedding with a dataset from Internet Movie Database (IMDB)\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - The objective is to guess whether film reviews are **positive or negative** based on the analysis of the text. \n",
    " - Understand the management of **textual data** and **sentiment analysis**\n",
    "\n",
    "Original dataset can be find **[there](http://ai.stanford.edu/~amaas/data/sentiment/)**  \n",
    "Note that [IMDb.com](https://imdb.com) offers several easy-to-use [datasets](https://www.imdb.com/interfaces/)  \n",
    "For simplicity's sake, we'll use the dataset directly [embedded in Keras](https://www.tensorflow.org/api_docs/python/tf/keras/datasets)\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Retrieve data\n",
    " - Preparing the data\n",
    " - Build a model\n",
    " - Train the model\n",
    " - Evaluate the result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       "\n",
       "div.warn {    \n",
       "    background-color: #fcf2f2;\n",
       "    border-color: #dFb5b4;\n",
       "    border-left: 5px solid #dfb5b4;\n",
       "    padding: 0.5em;\n",
       "    font-weight: bold;\n",
       "    font-size: 1.1em;;\n",
       "    }\n",
       "\n",
       "\n",
       "\n",
       "div.nota {    \n",
       "    background-color: #DAFFDE;\n",
       "    border-left: 5px solid #92CC99;\n",
       "    padding: 0.5em;\n",
       "    }\n",
       "\n",
       "div.todo:before { content:url();\n",
       "    float:left;\n",
       "    margin-right:20px;\n",
       "    margin-top:-20px;\n",
       "    margin-bottom:20px;\n",
       "}\n",
       "div.todo{\n",
       "    font-weight: bold;\n",
       "    font-size: 1.1em;\n",
       "    margin-top:40px;\n",
       "}\n",
       "div.todo ul{\n",
       "    margin: 0.2em;\n",
       "}\n",
       "div.todo li{\n",
       "    margin-left:60px;\n",
       "    margin-top:0;\n",
       "    margin-bottom:0;\n",
       "}\n",
       "\n",
       "div .comment{\n",
       "    font-size:0.8em;\n",
       "    color:#696969;\n",
       "}\n",
       "\n",
       "\n",
       "\n",
       "</style>\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**FIDLE 2020 - Practical Work Module**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Version              : 0.6.1 DEV\n",
      "Notebook id          : IMDB1\n",
      "Run time             : Friday 18 December 2020, 17:52:06\n",
      "TensorFlow version   : 2.0.0\n",
      "Keras version        : 2.2.4-tf\n",
      "Datasets dir         : /home/pjluc/datasets/fidle\n",
      "Running mode         : full\n",
      "Update keras cache   : False\n",
      "Save figs            : True\n",
      "Path figs            : ./run/figs\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras as keras\n",
    "import tensorflow.keras.datasets.imdb as imdb\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "import os,sys,h5py,json\n",
    "from importlib import reload\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as pwk\n",
    "\n",
    "datasets_dir = pwk.init('IMDB1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Retrieve data\n",
    "\n",
    "IMDb dataset can bet get directly from Keras - see [documentation](https://www.tensorflow.org/api_docs/python/tf/keras/datasets)  \n",
    "Note : Due to their nature, textual data can be somewhat complex.\n",
    "\n",
    "### 2.1 - Data structure :  \n",
    "The dataset is composed of 2 parts: \n",
    "\n",
    " - **reviews**, this will be our **x**\n",
    " - **opinions** (positive/negative), this will be our **y**\n",
    "\n",
    "There are also a **dictionary**, because words are indexed in reviews\n",
    "\n",
    "```\n",
    "<dataset> = (<reviews>, <opinions>)\n",
    "\n",
    "with :  <reviews>  = [ <review1>, <review2>, ... ]\n",
    "        <opinions> = [ <rate1>,   <rate2>,   ... ]   where <ratei>   = integer\n",
    "\n",
    "where : <reviewi> = [ <w1>, <w2>, ...]    <wi> are the index (int) of the word in the dictionary\n",
    "        <ratei>   = int                   0 for negative opinion, 1 for positive\n",
    "\n",
    "\n",
    "<dictionary> = [ <word1>:<w1>, <word2>:<w2>, ... ]\n",
    "\n",
    "with :  <wordi>   = word\n",
    "        <wi>      = int\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 - Get dataset\n",
    "For simplicity, we will use a pre-formatted dataset - See [documentation](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)  \n",
    "However, Keras offers some usefull tools for formatting textual data - See [documentation](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text)  \n",
    "\n",
    "**Load dataset :**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pjluc/anaconda3/envs/fidle/lib/python3.7/site-packages/tensorflow_core/python/keras/datasets/imdb.py:129: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])\n",
      "/home/pjluc/anaconda3/envs/fidle/lib/python3.7/site-packages/tensorflow_core/python/keras/datasets/imdb.py:130: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])\n"
     ]
    }
   ],
   "source": [
    "vocab_size = 10000\n",
    "\n",
    "# ----- Retrieve x,y\n",
    "\n",
    "# Uncomment this if you want to load dataset directly from keras (small size <20M)\n",
    "#\n",
    "(x_train, y_train), (x_test, y_test) = imdb.load_data( num_words  = vocab_size,\n",
    "                                                       skip_top   = 0,\n",
    "                                                       maxlen     = None,\n",
    "                                                       seed       = 42,\n",
    "                                                       start_char = 1,\n",
    "                                                       oov_char   = 2,\n",
    "                                                       index_from = 3, )\n",
    "\n",
    "# To load a h5 version of the dataset :\n",
    "#\n",
    "# with  h5py.File(f'{datasets_dir}/IMDB/origine/dataset_imdb.h5','r') as f:\n",
    "#        x_train = f['x_train'][:]\n",
    "#        y_train = f['y_train'][:]\n",
    "#        x_test  = f['x_test'][:]\n",
    "#        y_test  = f['y_test'][:]"
   ]
  },
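  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Encoding sketch (illustrative) :**  \n",
    "To make the encoding above more concrete, here is a minimal sketch that encodes a sentence by hand. The `toy_ranks` vocabulary and the `encode` helper are hypothetical and exist only for this example; the reserved values simply mirror the `start_char`, `oov_char` and `index_from` arguments passed to `load_data`, and the `num_words` cut-off is assumed to replace any index at or above the limit with the unknown tag.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy vocabulary: word -> frequency rank (1 = most frequent)\n",
    "toy_ranks = {'the': 1, 'film': 2, 'was': 3, 'great': 4, 'awful': 5}\n",
    "\n",
    "# Same reserved values as in the load_data call above\n",
    "START_CHAR, OOV_CHAR, INDEX_FROM = 1, 2, 3\n",
    "\n",
    "def encode(sentence, num_words=None):\n",
    "    '''Encode a sentence as a list of integers (toy version).'''\n",
    "    encoded = [START_CHAR]\n",
    "    for word in sentence.lower().split():\n",
    "        rank = toy_ranks.get(word)\n",
    "        index = None if rank is None else rank + INDEX_FROM\n",
    "        # Unknown words, or words beyond the vocabulary limit, become OOV_CHAR\n",
    "        if index is None or (num_words is not None and index >= num_words):\n",
    "            encoded.append(OOV_CHAR)\n",
    "        else:\n",
    "            encoded.append(index)\n",
    "    return encoded\n",
    "\n",
    "print(encode('The film was great'))   # [1, 4, 5, 6, 7]\n",
    "print(encode('The plot was awful'))   # [1, 4, 2, 6, 8]\n",
    "```\n",
    "\n",
    "The word `plot` is absent from the toy vocabulary, so it is encoded as 2, in the same way that rare words appear as 2 in the real reviews."
   ]
  },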
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**About this dataset :**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Max(x_train,x_test)  :  9999\n",
      "  x_train : (25000,)  y_train : (25000,)\n",
      "  x_test  : (25000,)  y_test  : (25000,)\n",
      "\n",
      "Review example (x_train[12]) :\n",
      "\n",
      " [1, 14, 22, 1367, 53, 206, 159, 4, 636, 898, 74, 26, 11, 436, 363, 108, 7, 14, 432, 14, 22, 9, 1055, 34, 8599, 2, 5, 381, 3705, 4509, 14, 768, 47, 839, 25, 111, 1517, 2579, 1991, 438, 2663, 587, 4, 280, 725, 6, 58, 11, 2714, 201, 4, 206, 16, 702, 5, 5176, 19, 480, 5920, 157, 13, 64, 219, 4, 2, 11, 107, 665, 1212, 39, 4, 206, 4, 65, 410, 16, 565, 5, 24, 43, 343, 17, 5602, 8, 169, 101, 85, 206, 108, 8, 3008, 14, 25, 215, 168, 18, 6, 2579, 1991, 438, 2, 11, 129, 1609, 36, 26, 66, 290, 3303, 46, 5, 633, 115, 4363]\n"
     ]
    }
   ],
   "source": [
    "print(\"  Max(x_train,x_test)  : \", pwk.rmax([x_train,x_test]) )\n",
    "print(\"  x_train : {}  y_train : {}\".format(x_train.shape, y_train.shape))\n",
    "print(\"  x_test  : {}  y_test  : {}\".format(x_test.shape,  y_test.shape))\n",
    "\n",
    "print('\\nReview example (x_train[12]) :\\n\\n',x_train[12])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 - Have a look for humans (optional)\n",
    "When we loaded the dataset, we asked for using \\<start\\> as 1, \\<unknown word\\> as 2  \n",
    "So, we shifted the dataset by 3 with the parameter index_from=3\n",
    "\n",
    "**Load dictionary :**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Retrieve dictionary {word:index}, and encode it in ascii\n",
    "#\n",
    "word_index = imdb.get_word_index()\n",
    "\n",
    "# ---- Shift the dictionary from +3\n",
    "#\n",
    "word_index = {w:(i+3) for w,i in word_index.items()}\n",
    "\n",
    "# ---- Add <pad>, <start> and unknown tags\n",
    "#\n",
    "word_index.update( {'<pad>':0, '<start>':1, '<unknown>':2} )\n",
    "\n",
    "# ---- Create a reverse dictionary : {index:word}\n",
    "#\n",
    "index_word = {index:word for word,index in word_index.items()} \n",
    "\n",
    "# ---- Add a nice function to transpose :\n",
    "#\n",
    "def dataset2text(review):\n",
    "    return ' '.join([index_word.get(i, '?') for i in review])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Have a look :**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dictionary size     :  88587\n",
      "440 : hope\n",
      "441 : entertaining\n",
      "442 : she's\n",
      "443 : mr\n",
      "444 : overall\n",
      "445 : evil\n",
      "446 : called\n",
      "447 : loved\n",
      "448 : based\n",
      "449 : oh\n",
      "450 : several\n",
      "451 : fans\n",
      "452 : mother\n",
      "453 : drama\n",
      "454 : beginning\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "<br>**Review example :**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 14, 22, 1367, 53, 206, 159, 4, 636, 898, 74, 26, 11, 436, 363, 108, 7, 14, 432, 14, 22, 9, 1055, 34, 8599, 2, 5, 381, 3705, 4509, 14, 768, 47, 839, 25, 111, 1517, 2579, 1991, 438, 2663, 587, 4, 280, 725, 6, 58, 11, 2714, 201, 4, 206, 16, 702, 5, 5176, 19, 480, 5920, 157, 13, 64, 219, 4, 2, 11, 107, 665, 1212, 39, 4, 206, 4, 65, 410, 16, 565, 5, 24, 43, 343, 17, 5602, 8, 169, 101, 85, 206, 108, 8, 3008, 14, 25, 215, 168, 18, 6, 2579, 1991, 438, 2, 11, 129, 1609, 36, 26, 66, 290, 3303, 46, 5, 633, 115, 4363]\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "<br>**After translation :**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<start> this film contains more action before the opening credits than are in entire hollywood films of this sort this film is produced by tsui <unknown> and stars jet li this team has brought you many worthy hong kong cinema productions including the once upon a time in china series the action was fast and furious with amazing wire work i only saw the <unknown> in two shots aside from the action the story itself was strong and not just used as filler to find any other action films to rival this you must look for a hong kong cinema <unknown> in your area they are really worth checking out and usually never disappoint\n"
     ]
    }
   ],
   "source": [
    "print('\\nDictionary size     : ', len(word_index))\n",
    "for k in range(440,455):print(f'{k:2d} : {index_word[k]}' )\n",
    "pwk.subtitle('Review example :')\n",
    "print(x_train[12])\n",
    "pwk.subtitle('After translation :')\n",
    "print(dataset2text(x_train[12]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 - Have a look for NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"comment\">Saved: ./run/figs/IMDB1-01-stats-sizes</div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sizes=[len(i) for i in x_train]\n",
    "plt.figure(figsize=(16,6))\n",
    "plt.hist(sizes, bins=400)\n",
    "plt.gca().set(title='Distribution of reviews by size - [{:5.2f}, {:5.2f}]'.format(min(sizes),max(sizes)), \n",
    "              xlabel='Size', ylabel='Density', xlim=[0,1500])\n",
    "pwk.save_fig('01-stats-sizes')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Preprocess the data (padding)\n",
    "In order to be processed by an NN, all entries must have the **same length.**  \n",
    "We chose a review length of **review_len**  \n",
    "We will therefore complete them with a padding (of \\<pad\\>\\)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<br>**After padding :**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[   1   14   22 1367   53  206  159    4  636  898   74   26   11  436\n",
      "  363  108    7   14  432   14   22    9 1055   34 8599    2    5  381\n",
      " 3705 4509   14  768   47  839   25  111 1517 2579 1991  438 2663  587\n",
      "    4  280  725    6   58   11 2714  201    4  206   16  702    5 5176\n",
      "   19  480 5920  157   13   64  219    4    2   11  107  665 1212   39\n",
      "    4  206    4   65  410   16  565    5   24   43  343   17 5602    8\n",
      "  169  101   85  206  108    8 3008   14   25  215  168   18    6 2579\n",
      " 1991  438    2   11  129 1609   36   26   66  290 3303   46    5  633\n",
      "  115 4363    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0]\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "<br>**In real words :**"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<start> this film contains more action before the opening credits than are in entire hollywood films of this sort this film is produced by tsui <unknown> and stars jet li this team has brought you many worthy hong kong cinema productions including the once upon a time in china series the action was fast and furious with amazing wire work i only saw the <unknown> in two shots aside from the action the story itself was strong and not just used as filler to find any other action films to rival this you must look for a hong kong cinema <unknown> in your area they are really worth checking out and usually never disappoint <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n"
     ]
    }
   ],
   "source": [
    "review_len = 256\n",
    "\n",
    "x_train = keras.preprocessing.sequence.pad_sequences(x_train,\n",
    "                                                     value   = 0,\n",
    "                                                     padding = 'post',\n",
    "                                                     maxlen  = review_len)\n",
    "\n",
    "x_test  = keras.preprocessing.sequence.pad_sequences(x_test,\n",
    "                                                     value   = 0 ,\n",
    "                                                     padding = 'post',\n",
    "                                                     maxlen  = review_len)\n",
    "\n",
    "pwk.subtitle('After padding :')\n",
    "print(x_train[12])\n",
    "pwk.subtitle('In real words :')\n",
    "print(dataset2text(x_train[12]))"
   ]
  },
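  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Padding sketch (illustrative) :**  \n",
    "The behaviour used above can be imitated in a few lines of plain Python. This is only a sketch: `pad_post` is a hypothetical helper, not part of Keras. It assumes, as with the default `truncating='pre'` of `pad_sequences`, that a review that is too long loses its first words, while a short one is completed at the end.\n",
    "\n",
    "```python\n",
    "def pad_post(sequence, maxlen, value=0):\n",
    "    '''Pad at the end; truncate from the beginning when too long (maxlen >= 1).'''\n",
    "    truncated = sequence[-maxlen:]          # keep the last maxlen items\n",
    "    return truncated + [value] * (maxlen - len(truncated))\n",
    "\n",
    "print(pad_post([1, 14, 22], maxlen=5))             # [1, 14, 22, 0, 0]\n",
    "print(pad_post([1, 14, 22, 9, 7, 3], maxlen=5))    # [14, 22, 9, 7, 3]\n",
    "```\n",
    "\n",
    "Note that in the second case the \\<start\\> tag (1) is dropped, which is also what can happen to the longest reviews of the dataset."
   ]
  },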
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Save dataset and dictionary (For future use but not mandatory)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved.\n"
     ]
    }
   ],
   "source": [
    "# ---- Write dataset in a h5 file, could be usefull\n",
    "#\n",
    "output_dir = './data'\n",
    "pwk.mkdir(output_dir)\n",
    "\n",
    "with h5py.File(f'{output_dir}/dataset_imdb.h5', 'w') as f:\n",
    "    f.create_dataset(\"x_train\",    data=x_train)\n",
    "    f.create_dataset(\"y_train\",    data=y_train)\n",
    "    f.create_dataset(\"x_test\",     data=x_test)\n",
    "    f.create_dataset(\"y_test\",     data=y_test)\n",
    "\n",
    "with open(f'{output_dir}/word_index.json', 'w') as fp:\n",
    "    json.dump(word_index, fp)\n",
    "\n",
    "with open(f'{output_dir}/index_word.json', 'w') as fp:\n",
    "    json.dump(index_word, fp)\n",
    "\n",
    "print('Saved.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Build the model\n",
    "Few remarks :\n",
    " - We'll choose a dense vector size for the embedding output with **dense_vector_size**\n",
    " - **GlobalAveragePooling1D** do a pooling on the last dimension : (None, lx, ly) -> (None, ly)  \n",
    "   In other words: we average the set of vectors/words of a sentence\n",
    " - L'embedding de Keras fonctionne de manière supervisée. Il s'agit d'une couche de *vocab_size* neurones vers *n_neurons* permettant de maintenir une table de vecteurs (les poids constituent les vecteurs). Cette couche ne calcule pas de sortie a la façon des couches normales, mais renvois la valeur des vecteurs. n mots => n vecteurs (ensuite empilés par le pooling)  \n",
    "Voir : [Explication plus détaillée (en)](https://stats.stackexchange.com/questions/324992/how-the-embedding-layer-is-trained-in-keras-embedding-layer)  \n",
    "ainsi que : [Sentiment detection with Keras](https://www.liip.ch/en/blog/sentiment-detection-with-keras-word-embeddings-and-lstm-deep-learning-networks)  \n",
    "\n",
    "More documentation about this model functions :\n",
    " - [Embedding](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding)\n",
    " - [GlobalAveragePooling1D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GlobalAveragePooling1D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(dense_vector_size=32):\n",
    "    \n",
    "    model = keras.Sequential()\n",
    "    model.add(keras.layers.Embedding(input_dim    = vocab_size, \n",
    "                                     output_dim   = dense_vector_size, \n",
    "                                     input_length = review_len))\n",
    "    model.add(keras.layers.GlobalAveragePooling1D())\n",
    "    model.add(keras.layers.Dense(dense_vector_size, activation='relu'))\n",
    "    model.add(keras.layers.Dense(1,                 activation='sigmoid'))\n",
    "\n",
    "    model.compile(optimizer = 'adam',\n",
    "                  loss      = 'binary_crossentropy',\n",
    "                  metrics   = ['accuracy'])\n",
    "    return model"
   ]
  },
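  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Embedding and pooling sketch (illustrative) :**  \n",
    "The first two layers of the model can be pictured as a table lookup followed by an average. The sketch below uses a hypothetical 4-word `table` with vectors of size 2, and the helpers `embed` and `global_average` are names made up for this example; in the real model the table has *vocab_size* rows and its values are learned during training.\n",
    "\n",
    "```python\n",
    "# Hypothetical embedding table: 4 words, vectors of size 2\n",
    "table = [[0.0, 0.0],\n",
    "         [1.0, 2.0],\n",
    "         [3.0, 4.0],\n",
    "         [5.0, 6.0]]\n",
    "\n",
    "def embed(review):\n",
    "    '''Look up one vector per word index: n words => n vectors.'''\n",
    "    return [table[i] for i in review]\n",
    "\n",
    "def global_average(vectors):\n",
    "    '''Average the vectors over the sequence axis: (lx, ly) -> (ly).'''\n",
    "    n = len(vectors)\n",
    "    return [sum(column) / n for column in zip(*vectors)]\n",
    "\n",
    "vectors = embed([1, 2, 3])        # 3 words => 3 vectors of size 2\n",
    "pooled  = global_average(vectors)\n",
    "print(pooled)                     # [3.0, 4.0]\n",
    "```\n",
    "\n",
    "Whatever the length of the review, the pooled result always has the size of one vector, which is what the following Dense layers expect."
   ]
  },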
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Train the model\n",
    "### 5.1 - Get it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (None, 256, 32)           320000    \n",
      "_________________________________________________________________\n",
      "global_average_pooling1d (Gl (None, 32)                0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 32)                1056      \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 33        \n",
      "=================================================================\n",
      "Total params: 321,089\n",
      "Trainable params: 321,089\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = get_model(32)\n",
    "\n",
    "model.summary()"
   ]
  },
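  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Where do these parameters come from ? (illustrative)**  \n",
    "The counts shown in the summary can be checked by hand. The short sketch below only redoes the arithmetic, using the values of `vocab_size` and `dense_vector_size` chosen earlier; the other variable names are introduced for this example.\n",
    "\n",
    "```python\n",
    "vocab_size, dense_vector_size = 10000, 32\n",
    "\n",
    "# Embedding: one vector of size dense_vector_size per word, no bias\n",
    "embedding_params = vocab_size * dense_vector_size\n",
    "\n",
    "# Hidden Dense layer: weights + biases\n",
    "hidden_params = dense_vector_size * dense_vector_size + dense_vector_size\n",
    "\n",
    "# Output Dense layer: 1 neuron, weights + bias\n",
    "output_params = dense_vector_size * 1 + 1\n",
    "\n",
    "total = embedding_params + hidden_params + output_params\n",
    "print(embedding_params, hidden_params, output_params, total)\n",
    "# 320000 1056 33 321089\n",
    "```\n",
    "\n",
    "Almost all of the parameters are in the embedding table; the pooling layer has none."
   ]
  },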
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 - Add callback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('./run/models',   mode=0o750, exist_ok=True)\n",
    "save_dir = \"./run/models/best_model.h5\"\n",
    "savemodel_callback = tf.keras.callbacks.ModelCheckpoint(filepath=save_dir, verbose=0, save_best_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 - Train it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 25000 samples, validate on 25000 samples\n",
      "Epoch 1/30\n",
      "25000/25000 [==============================] - 1s 57us/sample - loss: 0.6881 - accuracy: 0.6431 - val_loss: 0.6782 - val_accuracy: 0.7408\n",
      "Epoch 2/30\n",
      "25000/25000 [==============================] - 1s 32us/sample - loss: 0.6506 - accuracy: 0.7618 - val_loss: 0.6168 - val_accuracy: 0.7650\n",
      "Epoch 3/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.5590 - accuracy: 0.8060 - val_loss: 0.5135 - val_accuracy: 0.8203\n",
      "Epoch 4/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.4460 - accuracy: 0.8512 - val_loss: 0.4198 - val_accuracy: 0.8487\n",
      "Epoch 5/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.3606 - accuracy: 0.8758 - val_loss: 0.3636 - val_accuracy: 0.8609\n",
      "Epoch 6/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.3074 - accuracy: 0.8894 - val_loss: 0.3322 - val_accuracy: 0.8676\n",
      "Epoch 7/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.2719 - accuracy: 0.9016 - val_loss: 0.3127 - val_accuracy: 0.8734\n",
      "Epoch 8/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.2457 - accuracy: 0.9103 - val_loss: 0.3007 - val_accuracy: 0.8765\n",
      "Epoch 9/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.2257 - accuracy: 0.9178 - val_loss: 0.2933 - val_accuracy: 0.8795\n",
      "Epoch 10/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.2083 - accuracy: 0.9249 - val_loss: 0.2888 - val_accuracy: 0.8818\n",
      "Epoch 11/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1938 - accuracy: 0.9310 - val_loss: 0.2874 - val_accuracy: 0.8824\n",
      "Epoch 12/30\n",
      "25000/25000 [==============================] - 1s 31us/sample - loss: 0.1818 - accuracy: 0.9352 - val_loss: 0.2867 - val_accuracy: 0.8826\n",
      "Epoch 13/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1703 - accuracy: 0.9406 - val_loss: 0.2879 - val_accuracy: 0.8828\n",
      "Epoch 14/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1605 - accuracy: 0.9451 - val_loss: 0.2922 - val_accuracy: 0.8815\n",
      "Epoch 15/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.1520 - accuracy: 0.9480 - val_loss: 0.2945 - val_accuracy: 0.8828\n",
      "Epoch 16/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1435 - accuracy: 0.9524 - val_loss: 0.2986 - val_accuracy: 0.8821\n",
      "Epoch 17/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1359 - accuracy: 0.9551 - val_loss: 0.3042 - val_accuracy: 0.8803\n",
      "Epoch 18/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.1290 - accuracy: 0.9581 - val_loss: 0.3100 - val_accuracy: 0.8783\n",
      "Epoch 19/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1223 - accuracy: 0.9609 - val_loss: 0.3169 - val_accuracy: 0.8772\n",
      "Epoch 20/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.1167 - accuracy: 0.9632 - val_loss: 0.3245 - val_accuracy: 0.8747\n",
      "Epoch 21/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.1116 - accuracy: 0.9654 - val_loss: 0.3318 - val_accuracy: 0.8756\n",
      "Epoch 22/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.1064 - accuracy: 0.9670 - val_loss: 0.3409 - val_accuracy: 0.8743\n",
      "Epoch 23/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.1007 - accuracy: 0.9704 - val_loss: 0.3478 - val_accuracy: 0.8728\n",
      "Epoch 24/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.0963 - accuracy: 0.9718 - val_loss: 0.3577 - val_accuracy: 0.8718\n",
      "Epoch 25/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.0914 - accuracy: 0.9743 - val_loss: 0.3728 - val_accuracy: 0.8670\n",
      "Epoch 26/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.0887 - accuracy: 0.9744 - val_loss: 0.3746 - val_accuracy: 0.8697\n",
      "Epoch 27/30\n",
      "25000/25000 [==============================] - 1s 29us/sample - loss: 0.0842 - accuracy: 0.9777 - val_loss: 0.3835 - val_accuracy: 0.8683\n",
      "Epoch 28/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.0796 - accuracy: 0.9788 - val_loss: 0.3937 - val_accuracy: 0.8670\n",
      "Epoch 29/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.0763 - accuracy: 0.9804 - val_loss: 0.4032 - val_accuracy: 0.8658\n",
      "Epoch 30/30\n",
      "25000/25000 [==============================] - 1s 30us/sample - loss: 0.0729 - accuracy: 0.9819 - val_loss: 0.4156 - val_accuracy: 0.8648\n",
      "CPU times: user 1min 42s, sys: 7.41 s, total: 1min 50s\n",
      "Wall time: 23.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "n_epochs   = 30\n",
    "batch_size = 512\n",
    "\n",
    "history = model.fit(x_train,\n",
    "                    y_train,\n",
    "                    epochs          = n_epochs,\n",
    "                    batch_size      = batch_size,\n",
    "                    validation_data = (x_test, y_test),\n",
    "                    verbose         = 1,\n",
    "                    callbacks       = [savemodel_callback])\n"
   ]
  },
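  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Which epoch was saved ? (illustrative)**  \n",
    "With `save_best_only=True`, the callback keeps the model of the best epoch; as far as we know, `ModelCheckpoint` monitors `val_loss` by default. The sketch below looks for that epoch by hand. The `val_loss` values are copied from epochs 9 to 15 of the log above, and the variable names are introduced only for this example.\n",
    "\n",
    "```python\n",
    "first_epoch = 9\n",
    "val_loss = [0.2933, 0.2888, 0.2874, 0.2867, 0.2879, 0.2922, 0.2945]\n",
    "\n",
    "# Position of the smallest validation loss in the list\n",
    "best_offset = min(range(len(val_loss)), key=val_loss.__getitem__)\n",
    "best_epoch  = first_epoch + best_offset\n",
    "\n",
    "print(best_epoch, val_loss[best_offset])   # 12 0.2867\n",
    "```\n",
    "\n",
    "After epoch 12 the validation loss rises again while the training loss keeps falling, which suggests that the model starts to overfit."
   ]
  },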
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - Evaluate\n",
    "### 6.1 - Training history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"comment\">Saved: ./run/figs/IMDB1-02-history_0</div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"comment\">Saved: ./run/figs/IMDB1-02-history_1</div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "pwk.plot_history(history, save_as='02-history')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 - Reload and evaluate the best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_test / loss      : 0.2867\n",
      "x_test / accuracy  : 0.8826\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Accuracy donut is :"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"comment\">Saved: ./run/figs/IMDB1-03-donut</div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "#### Confusion matrix is :"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style  type=\"text/css\" >\n",
       "#T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row0_col0,#T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row1_col1{\n",
       "            background-color:  #ffe4c4;\n",
       "            color:  #000000;\n",
       "            font-size:  12pt;\n",
       "        }#T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row0_col1,#T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row1_col0{\n",
       "            background-color:  #f5f5f5;\n",
       "            color:  #000000;\n",
       "            font-size:  12pt;\n",
       "        }</style><table id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4\" ><thead>    <tr>        <th class=\"blank level0\" ></th>        <th class=\"col_heading level0 col0\" >0</th>        <th class=\"col_heading level0 col1\" >1</th>    </tr></thead><tbody>\n",
       "                <tr>\n",
       "                        <th id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "                        <td id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row0_col0\" class=\"data row0 col0\" >0.88</td>\n",
       "                        <td id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row0_col1\" class=\"data row0 col1\" >0.12</td>\n",
       "            </tr>\n",
       "            <tr>\n",
       "                        <th id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "                        <td id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row1_col0\" class=\"data row1 col0\" >0.11</td>\n",
       "                        <td id=\"T_2dfd32c2_4152_11eb_b4d1_310d57f4d2a4row1_col1\" class=\"data row1 col1\" >0.89</td>\n",
       "            </tr>\n",
       "    </tbody></table>"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f943eb08dd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"comment\">Saved: ./run/figs/IMDB1-04-confusion-matrix</div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# ---- Reload the model saved during training (see savemodel_callback)\n",
    "model = keras.models.load_model('./run/models/best_model.h5')\n",
    "\n",
    "# ---- Evaluate on the test set: score[0] is the loss, score[1] the accuracy\n",
    "score = model.evaluate(x_test, y_test, verbose=0)\n",
    "\n",
    "print('x_test / loss      : {:5.4f}'.format(score[0]))\n",
    "print('x_test / accuracy  : {:5.4f}'.format(score[1]))\n",
    "\n",
    "values = [score[1], 1-score[1]]\n",
    "pwk.plot_donut(values, [\"Accuracy\", \"Errors\"], title=\"#### Accuracy donut is :\", save_as='03-donut')\n",
    "\n",
    "# ---- Confusion matrix\n",
    "\n",
    "# Raw model outputs, one score per review\n",
    "y_sigmoid = model.predict(x_test)\n",
    "\n",
    "# Threshold at 0.5 to turn each score into a class label (0 or 1)\n",
    "y_pred = y_sigmoid.copy()\n",
    "y_pred[y_sigmoid <  0.5] = 0\n",
    "y_pred[y_sigmoid >= 0.5] = 1\n",
    "\n",
    "pwk.display_confusion_matrix(y_test, y_pred, labels=range(2))\n",
    "pwk.plot_confusion_matrix(y_test, y_pred, range(2), figsize=(8, 8), normalize=False, save_as='04-confusion-matrix')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "End time is : Friday 18 December 2020, 17:52:43\n",
      "Duration is : 00:00:37 585ms\n",
      "This notebook ends here\n"
     ]
    }
   ],
   "source": [
    "pwk.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}