{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
"# <!-- TITLE --> [IMDB1] - Text embedding with IMDB\n",
"<!-- DESC --> A very classical example of word embedding for text classification (sentiment analysis)\n",
"<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
"## Objectives :\n",
" - The objective is to guess whether film reviews are **positive or negative** based on the analysis of the text. \n",
" - Understand the management of **textual data** and **sentiment analysis**\n",
"\n",
"Original dataset can be find **[there](http://ai.stanford.edu/~amaas/data/sentiment/)** \n",
"Note that [IMDb.com](https://imdb.com) offers several easy-to-use [datasets](https://www.imdb.com/interfaces/) \n",
"For simplicity's sake, we'll use the dataset directly [embedded in Keras](https://www.tensorflow.org/api_docs/python/tf/keras/datasets)\n",
"\n",
"## What we're going to do :\n",
"\n",
" - Retrieve data\n",
" - Preparing the data\n",
" - Build a model\n",
" - Train the model\n",
" - Evaluate the result\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"execution_count": 1,
"outputs": [
{
"data": {
"text/html": [
"<style>\n",
"\n",
"div.warn { \n",
" background-color: #fcf2f2;\n",
" border-color: #dFb5b4;\n",
" border-left: 5px solid #dfb5b4;\n",
" padding: 0.5em;\n",
" font-weight: bold;\n",
" font-size: 1.1em;;\n",
" }\n",
"\n",
"\n",
"\n",
"div.nota { \n",
" background-color: #DAFFDE;\n",
" border-left: 5px solid #92CC99;\n",
" padding: 0.5em;\n",
" }\n",
"\n",
"div.todo:before { content:url();\n",
" float:left;\n",
" margin-right:20px;\n",
" margin-top:-20px;\n",
" margin-bottom:20px;\n",
"}\n",
"div.todo{\n",
" font-weight: bold;\n",
" font-size: 1.1em;\n",
" margin-top:40px;\n",
"}\n",
"div.todo ul{\n",
" margin: 0.2em;\n",
"}\n",
"div.todo li{\n",
" margin-left:60px;\n",
" margin-top:0;\n",
" margin-bottom:0;\n",
"}\n",
"\n",
"\n",
"</style>\n",
"\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"FIDLE 2020 - Practical Work Module\n",
"Version : 0.57 DEV\n",
"Run time : Thursday 10 September 2020, 16:34:04\n",
"TensorFlow version : 2.2.0\n",
"Keras version : 2.3.0-tf\n",
"Current place : Fidle at IDRIS\n",
"Dataset dir : /gpfswork/rech/mlh/commun/datasets\n",
"Update keras cache : Done\n"
"\n",
"import tensorflow as tf\n",
"import tensorflow.keras as keras\n",
"import tensorflow.keras.datasets.imdb as imdb\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"sys.path.append('..')\n",
"import fidle.pwk as ooo\n",
"\n",
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"**From Keras :**\n",
"This IMDb dataset can bet get directly from [Keras datasets](https://www.tensorflow.org/api_docs/python/tf/keras/datasets) \n",
"\n",
"Due to their nature, textual data can be somewhat complex.\n",
"\n",
"The dataset is composed of 2 parts: **reviews** and **opinions** (positive/negative), with a **dictionary**\n",
"\n",
" - dataset = (reviews, opinions)\n",
" - reviews = \\[ review_0, review_1, ...\\]\n",
" - review_i = [ int1, int2, ...] where int_i is the index of the word in the dictionary.\n",
" - opinions = \\[ int0, int1, ...\\] where int_j == 0 if opinion is negative or 1 if opinion is positive.\n",
" - dictionary = \\[ mot1:int1, mot2:int2, ... ]"
]
},
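To make this layout concrete, here is a minimal sketch with a made-up three-word dictionary and two tiny reviews. The words, indices and labels are invented for illustration and are not taken from the real IMDb data; `decode` is a hypothetical helper.

```python
# Toy illustration of the (reviews, opinions, dictionary) layout.
# All values below are hypothetical, not real IMDb data.
dictionary = {"great": 1, "boring": 2, "film": 3}

reviews = [
    [1, 3],   # "great film"
    [2, 3],   # "boring film"
]
opinions = [1, 0]  # 1 = positive, 0 = negative

# Reverse mapping (index -> word), used to decode a review.
index_to_word = {index: word for word, index in dictionary.items()}

def decode(review):
    """Turn a list of word indices back into text."""
    return " ".join(index_to_word[i] for i in review)

for review, opinion in zip(reviews, opinions):
    label = "positive" if opinion == 1 else "negative"
    print(f"{decode(review)!r} -> {label}")
```

Each review is a variable-length list of integers, and the opinion at the same position is its label.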
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For simplicity, we will use a pre-formatted dataset. \n",
"See : https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data \n",
"\n",
"However, Keras offers some usefull tools for formatting textual data. \n",
"See : https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# ----- Retrieve x,y\n",
"#\n",
"# Choose if you want to load dataset directly from keras (small size <20M)\n",
"(x_train, y_train), (x_test, y_test) = imdb.load_data( num_words = vocab_size,\n",
" skip_top = 0,\n",
" maxlen = None,\n",
" seed = 42,\n",
" start_char = 1,\n",
" oov_char = 2,\n",
" index_from = 3, )\n",
"# Or you can use the same pre-loaded dataset\n",
"# with h5py.File(f'{datasets_dir}/IMDB/origine/dataset_imdb.h5','r') as f:\n",
"# x_train = f['x_train'][:]\n",
"# y_train = f['y_train'][:]\n",
"# x_test = f['x_test'][:]\n",
"# y_test = f['y_test'][:]"
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Max(x_train,x_test) : 9999\n",
" x_train : (25000,) y_train : (25000,)\n",
" x_test : (25000,) y_test : (25000,)\n",
"\n",
"Review example (x_train[12]) :\n",
"\n",
" [1, 14, 22, 1367, 53, 206, 159, 4, 636, 898, 74, 26, 11, 436, 363, 108, 7, 14, 432, 14, 22, 9, 1055, 34, 8599, 2, 5, 381, 3705, 4509, 14, 768, 47, 839, 25, 111, 1517, 2579, 1991, 438, 2663, 587, 4, 280, 725, 6, 58, 11, 2714, 201, 4, 206, 16, 702, 5, 5176, 19, 480, 5920, 157, 13, 64, 219, 4, 2, 11, 107, 665, 1212, 39, 4, 206, 4, 65, 410, 16, 565, 5, 24, 43, 343, 17, 5602, 8, 169, 101, 85, 206, 108, 8, 3008, 14, 25, 215, 168, 18, 6, 2579, 1991, 438, 2, 11, 129, 1609, 36, 26, 66, 290, 3303, 46, 5, 633, 115, 4363]\n"
]
}
],
"source": [
"print(\" Max(x_train,x_test) : \", ooo.rmax([x_train,x_test]) )\n",
"print(\" x_train : {} y_train : {}\".format(x_train.shape, y_train.shape))\n",
"print(\" x_test : {} y_test : {}\".format(x_test.shape, y_test.shape))\n",
"\n",
"print('\\nReview example (x_train[12]) :\\n\\n',x_train[12])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we loaded the dataset, we asked for using \\<start\\> as 1, \\<unknown word\\> as 2 \n",
"So, we shifted the dataset by 3 with the parameter index_from=3"
]
},
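The effect of this shift can be sketched on a toy dictionary. This is a minimal illustration, not the Keras implementation: `raw_word_index`, `encode` and the tiny vocabulary are hypothetical, and only the reserved values (0, 1, 2) and the shift of 3 come from the text above.

```python
# Hypothetical raw dictionary, as a word -> rank lookup might return it (ranks start at 1).
raw_word_index = {"the": 1, "film": 2, "great": 3}

INDEX_FROM = 3          # same role as index_from=3 in load_data
START, UNKNOWN = 1, 2   # reserved indices; 0 is kept for padding

def encode(words, vocab_size):
    """Encode a list of words: <start> first, shifted indices, <unknown> otherwise."""
    encoded = [START]
    for word in words:
        rank = raw_word_index.get(word)
        if rank is None or rank + INDEX_FROM >= vocab_size:
            encoded.append(UNKNOWN)      # word missing or outside the vocabulary
        else:
            encoded.append(rank + INDEX_FROM)
    return encoded

print(encode(["the", "great", "film"], vocab_size=10))   # [1, 4, 6, 5]
print(encode(["the", "awful", "film"], vocab_size=10))   # [1, 4, 2, 5]
```

Because real words start at index 4, the values 0 to 2 never collide with an actual word.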
{
"cell_type": "code",
"# ---- Retrieve dictionary {word:index}, and encode it in ascii\n",
"word_index = imdb.get_word_index()\n",
"\n",
"# ---- Shift the dictionary from +3\n",
"word_index = {w:(i+3) for w,i in word_index.items()}\n",
"\n",
"# ---- Add <pad>, <start> and unknown tags\n",
"word_index.update( {'<pad>':0, '<start>':1, '<unknown>':2} )\n",
"\n",
"# ---- Create a reverse dictionary : {index:word}\n",
"index_word = {index:word for word,index in word_index.items()} \n",
"\n",
"# ---- Add a nice function to transpose :\n",
"#\n",
"def dataset2text(review):\n",
" return ' '.join([index_word.get(i, '?') for i in review])"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Dictionary size : 88587\n",
"440 : hope\n",
"441 : entertaining\n",
"442 : she's\n",
"443 : mr\n",
"444 : overall\n",
"445 : evil\n",
"446 : called\n",
"447 : loved\n",
"448 : based\n",
"449 : oh\n",
"450 : several\n",
"451 : fans\n",
"452 : mother\n",
"453 : drama\n",
"454 : beginning\n",
"\n",
"Review example (x_train[12]) :\n",
"\n",
" [1, 14, 22, 1367, 53, 206, 159, 4, 636, 898, 74, 26, 11, 436, 363, 108, 7, 14, 432, 14, 22, 9, 1055, 34, 8599, 2, 5, 381, 3705, 4509, 14, 768, 47, 839, 25, 111, 1517, 2579, 1991, 438, 2663, 587, 4, 280, 725, 6, 58, 11, 2714, 201, 4, 206, 16, 702, 5, 5176, 19, 480, 5920, 157, 13, 64, 219, 4, 2, 11, 107, 665, 1212, 39, 4, 206, 4, 65, 410, 16, 565, 5, 24, 43, 343, 17, 5602, 8, 169, 101, 85, 206, 108, 8, 3008, 14, 25, 215, 168, 18, 6, 2579, 1991, 438, 2, 11, 129, 1609, 36, 26, 66, 290, 3303, 46, 5, 633, 115, 4363]\n",
"\n",
"In real words :\n",
"\n",
" <start> this film contains more action before the opening credits than are in entire hollywood films of this sort this film is produced by tsui <unknown> and stars jet li this team has brought you many worthy hong kong cinema productions including the once upon a time in china series the action was fast and furious with amazing wire work i only saw the <unknown> in two shots aside from the action the story itself was strong and not just used as filler to find any other action films to rival this you must look for a hong kong cinema <unknown> in your area they are really worth checking out and usually never disappoint\n"
]
}
],
"source": [
"print('\\nDictionary size : ', len(word_index))\n",
"for k in range(440,455):print(f'{k:2d} : {index_word[k]}' )\n",
"print('\\nReview example (x_train[12]) :\\n\\n',x_train[12])\n",
"print('\\nIn real words :\\n\\n', dataset2text(x_train[12]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"outputs": [
{
"data": {
"image/png": "\n",
"<Figure size 864x432 with 1 Axes>"
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"ax=sns.histplot([len(i) for i in x_train],bins=60)\n",
"ax.set_title('Distribution of reviews by size')\n",
"plt.xlabel(\"Review's sizes\")\n",
"plt.ylabel('Density')\n",
"ax.set_xlim(0, 1500)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3 - Preprocess the data (padding)\n",
"In order to be processed by an NN, all entries must have the same length. \n",
"We will therefore complete them with a padding (of \\<pad\\>\\) "
]
},
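What "post" padding does can be sketched in plain NumPy. This is an illustrative stand-in, not the Keras function: `pad_post` is a hypothetical helper, and unlike the default behaviour of `pad_sequences` (which drops the beginning of sequences that are too long), this sketch simply keeps the first `maxlen` tokens.

```python
import numpy as np

def pad_post(sequences, maxlen, value=0):
    """Right-pad (or cut) each sequence so that it has exactly `maxlen` items."""
    padded = np.full((len(sequences), maxlen), value, dtype=np.int64)
    for row, seq in enumerate(sequences):
        kept = seq[:maxlen]               # keep at most the first `maxlen` tokens
        padded[row, :len(kept)] = kept    # the remaining cells stay at `value`
    return padded

# Two hypothetical encoded reviews of different lengths.
reviews = [[1, 14, 22], [1, 4, 65, 410, 16, 565]]
print(pad_post(reviews, maxlen=5))
```

The result is a regular 2-D array, which is what a neural network layer expects as input.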
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Review example (x_train[12]) :\n",
"\n",
" [ 1 14 22 1367 53 206 159 4 636 898 74 26 11 436\n",
" 363 108 7 14 432 14 22 9 1055 34 8599 2 5 381\n",
" 3705 4509 14 768 47 839 25 111 1517 2579 1991 438 2663 587\n",
" 4 280 725 6 58 11 2714 201 4 206 16 702 5 5176\n",
" 19 480 5920 157 13 64 219 4 2 11 107 665 1212 39\n",
" 4 206 4 65 410 16 565 5 24 43 343 17 5602 8\n",
" 169 101 85 206 108 8 3008 14 25 215 168 18 6 2579\n",
" 1991 438 2 11 129 1609 36 26 66 290 3303 46 5 633\n",
" 115 4363 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0]\n",
"\n",
"In real words :\n",
"\n",
" <start> this film contains more action before the opening credits than are in entire hollywood films of this sort this film is produced by tsui <unknown> and stars jet li this team has brought you many worthy hong kong cinema productions including the once upon a time in china series the action was fast and furious with amazing wire work i only saw the <unknown> in two shots aside from the action the story itself was strong and not just used as filler to find any other action films to rival this you must look for a hong kong cinema <unknown> in your area they are really worth checking out and usually never disappoint <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n"
]
}
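],
"source": [
"# ---- Fixed review length (256 is consistent with the padded example and the model summary)\n",
"review_len = 256\n",
"\n",
"x_train = keras.preprocessing.sequence.pad_sequences(x_train,\n",
"                                                     value   = 0,\n",
"                                                     padding = 'post',\n",
"                                                     maxlen  = review_len)\n",
"\n",
"x_test  = keras.preprocessing.sequence.pad_sequences(x_test,\n",
"                                                     value   = 0,\n",
"                                                     padding = 'post',\n",
"                                                     maxlen  = review_len)\n",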
"\n",
"print('\\nReview example (x_train[12]) :\\n\\n',x_train[12])\n",
"print('\\nIn real words :\\n\\n', dataset2text(x_train[12]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save dataset and dictionary (For future use but not mandatory if at GRICAD or IDRIS)"
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved.\n"
]
}
],
"source": [
"# ---- To write h5 dataset in a test place (optional)\n",
"# For small tests only !\n",
"#\n",
"output_dir = './data'\n",
"ooo.mkdir(output_dir)\n",
"with h5py.File(f'{output_dir}/dataset_imdb.h5', 'w') as f:\n",
" f.create_dataset(\"x_train\", data=x_train)\n",
" f.create_dataset(\"y_train\", data=y_train)\n",
" f.create_dataset(\"x_test\", data=x_test)\n",
" f.create_dataset(\"y_test\", data=y_test)\n",
"\n",
"with open(f'{output_dir}/word_index.json', 'w') as fp:\n",
"with open(f'{output_dir}/index_word.json', 'w') as fp:\n",
" json.dump(index_word, fp)\n",
"\n",
"print('Saved.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4 - Build the model\n",
"1. We'll choose a dense vector size for the embedding output with **dense_vector_size**\n",
"2. **GlobalAveragePooling1D** do a pooling on the last dimension : (None, lx, ly) -> (None, ly) \n",
"In other words: we average the set of vectors/words of a sentence\n",
"3. L'embedding de Keras fonctionne de manière supervisée. Il s'agit d'une couche de *vocab_size* neurones vers *n_neurons* permettant de maintenir une table de vecteurs (les poids constituent les vecteurs). Cette couche ne calcule pas de sortie a la façon des couches normales, mais renvois la valeur des vecteurs. n mots => n vecteurs (ensuite empilés par le pooling) \n",
"Voir : https://stats.stackexchange.com/questions/324992/how-the-embedding-layer-is-trained-in-keras-embedding-layer\n",
"\n",
"A SUIVRE : https://www.liip.ch/en/blog/sentiment-detection-with-keras-word-embeddings-and-lstm-deep-learning-networks\n",
"### 4.1 - Build\n",
"More documentation about :\n",
" - [Embedding](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding)\n",
" - [GlobalAveragePooling1D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GlobalAveragePooling1D)"
]
},
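The two ideas above (embedding as a table lookup, pooling as an average over the words) can be sketched with NumPy alone. This is a toy illustration under stated assumptions: the sizes, the random table and the batch are hypothetical, and the table is fixed here whereas Keras learns it during training. Like an unmasked pooling layer, the mean below also counts the padding positions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (hypothetical): 10 known indices, vectors of 4 values.
vocab_size, dense_vector_size = 10, 4
embedding_table = rng.normal(size=(vocab_size, dense_vector_size))

# A batch of 2 padded reviews, 5 word indices each.
batch = np.array([[1, 4, 6, 5, 0],
                  [1, 4, 2, 0, 0]])

# Embedding: a plain table lookup, (2, 5) -> (2, 5, 4)
vectors = embedding_table[batch]

# Global average pooling: mean over the word axis, (2, 5, 4) -> (2, 4)
pooled = vectors.mean(axis=1)

print(vectors.shape, pooled.shape)
```

Each review ends up as a single vector of size `dense_vector_size`, whatever its length, which is what the dense layers then classify.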
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def get_model(dense_vector_size=32):\n",
" \n",
" model = keras.Sequential()\n",
" model.add(keras.layers.Embedding(input_dim = vocab_size, \n",
" output_dim = dense_vector_size, \n",
" input_length = review_len))\n",
" model.add(keras.layers.GlobalAveragePooling1D())\n",
" model.add(keras.layers.Dense(dense_vector_size, activation='relu'))\n",
" model.add(keras.layers.Dense(1, activation='sigmoid'))\n",
"\n",
" model.compile(optimizer = 'adam',\n",
" loss = 'binary_crossentropy',\n",
" metrics = ['accuracy'])\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5 - Train the model\n",
"### 5.1 - Get it"
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding (Embedding) (None, 256, 32) 320000 \n",
"_________________________________________________________________\n",
"global_average_pooling1d (Gl (None, 32) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 32) 1056 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1) 33 \n",
"=================================================================\n",
"Total params: 321,089\n",
"Trainable params: 321,089\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = get_model(32)\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"metadata": {},
"outputs": [],
"source": [
"os.makedirs('./run/models', mode=0o750, exist_ok=True)\n",
"save_dir = \"./run/models/best_model.h5\"\n",
"savemodel_callback = tf.keras.callbacks.ModelCheckpoint(filepath=save_dir, verbose=0, save_best_only=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.1 - Train it"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"49/49 [==============================] - 1s 23ms/step - loss: 0.6884 - accuracy: 0.6085 - val_loss: 0.6783 - val_accuracy: 0.6292\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.6497 - accuracy: 0.7349 - val_loss: 0.6142 - val_accuracy: 0.7631\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.5548 - accuracy: 0.8092 - val_loss: 0.5105 - val_accuracy: 0.8291\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.4420 - accuracy: 0.8572 - val_loss: 0.4182 - val_accuracy: 0.8461\n",
"49/49 [==============================] - 1s 17ms/step - loss: 0.3576 - accuracy: 0.8790 - val_loss: 0.3624 - val_accuracy: 0.8609\n",
"49/49 [==============================] - 1s 17ms/step - loss: 0.3047 - accuracy: 0.8928 - val_loss: 0.3303 - val_accuracy: 0.8690\n",
"49/49 [==============================] - 1s 17ms/step - loss: 0.2689 - accuracy: 0.9030 - val_loss: 0.3116 - val_accuracy: 0.8744\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.2426 - accuracy: 0.9122 - val_loss: 0.3002 - val_accuracy: 0.8775\n",
"49/49 [==============================] - 1s 17ms/step - loss: 0.2219 - accuracy: 0.9204 - val_loss: 0.2924 - val_accuracy: 0.8802\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.2054 - accuracy: 0.9275 - val_loss: 0.2883 - val_accuracy: 0.8826\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1910 - accuracy: 0.9310 - val_loss: 0.2881 - val_accuracy: 0.8820\n",
"49/49 [==============================] - 1s 17ms/step - loss: 0.1784 - accuracy: 0.9374 - val_loss: 0.2873 - val_accuracy: 0.8830\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1663 - accuracy: 0.9423 - val_loss: 0.2919 - val_accuracy: 0.8810\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1570 - accuracy: 0.9472 - val_loss: 0.2921 - val_accuracy: 0.8826\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1478 - accuracy: 0.9508 - val_loss: 0.2968 - val_accuracy: 0.8816\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1401 - accuracy: 0.9532 - val_loss: 0.3015 - val_accuracy: 0.8815\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1326 - accuracy: 0.9566 - val_loss: 0.3077 - val_accuracy: 0.8782\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1254 - accuracy: 0.9595 - val_loss: 0.3155 - val_accuracy: 0.8795\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1192 - accuracy: 0.9627 - val_loss: 0.3214 - val_accuracy: 0.8761\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1129 - accuracy: 0.9647 - val_loss: 0.3292 - val_accuracy: 0.8745\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1074 - accuracy: 0.9672 - val_loss: 0.3374 - val_accuracy: 0.8735\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.1019 - accuracy: 0.9698 - val_loss: 0.3483 - val_accuracy: 0.8698\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0975 - accuracy: 0.9715 - val_loss: 0.3538 - val_accuracy: 0.8726\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0924 - accuracy: 0.9731 - val_loss: 0.3640 - val_accuracy: 0.8701\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0883 - accuracy: 0.9755 - val_loss: 0.3732 - val_accuracy: 0.8703\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0838 - accuracy: 0.9776 - val_loss: 0.3825 - val_accuracy: 0.8686\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0800 - accuracy: 0.9792 - val_loss: 0.3939 - val_accuracy: 0.8678\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0765 - accuracy: 0.9807 - val_loss: 0.4021 - val_accuracy: 0.8666\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0730 - accuracy: 0.9821 - val_loss: 0.4218 - val_accuracy: 0.8646\n",
"49/49 [==============================] - 1s 16ms/step - loss: 0.0696 - accuracy: 0.9835 - val_loss: 0.4242 - val_accuracy: 0.8646\n",
"CPU times: user 37.1 s, sys: 1.5 s, total: 38.6 s\n",
"Wall time: 25.7 s\n"
"source": [
"%%time\n",
"\n",
"n_epochs = 30\n",
"batch_size = 512\n",
"\n",
"history = model.fit(x_train,\n",
" y_train,\n",
" epochs = n_epochs,\n",
" batch_size = batch_size,\n",
" validation_data = (x_test, y_test),\n",
" verbose = 1,\n",
" callbacks = [savemodel_callback])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6 - Evaluate\n",
"### 6.1 - Training history"
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"ooo.plot_history(history)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.2 - Reload and evaluate best model"
]
},
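A point worth checking when building the confusion matrix: the model ends with a single sigmoid unit, so its predictions have shape (n, 1). The sketch below, on made-up probabilities, shows that an argmax over that single column can only return 0, whereas thresholding at 0.5 gives the expected 0/1 labels. The values are hypothetical and not outputs of the trained model.

```python
import numpy as np

# Hypothetical sigmoid outputs for 4 reviews, shape (4, 1).
y_sigmoid = np.array([[0.91], [0.12], [0.55], [0.49]])

# argmax over a single column can only ever return index 0.
y_argmax = np.argmax(y_sigmoid, axis=-1)

# Thresholding at 0.5 gives the intended 0/1 labels.
y_pred = (y_sigmoid > 0.5).astype(int).ravel()

print(y_argmax)  # [0 0 0 0]
print(y_pred)    # [1 0 1 0]
```

A confusion matrix in which every review is predicted as class 0 is the typical symptom of the argmax variant.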
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_test / loss : 0.2873\n",
"x_test / accuracy : 0.8830\n"
]
},
{
"data": {
"text/markdown": [
"#### Accuracy donut is :"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"#### Confusion matrix is :"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style type=\"text/css\" >\n",
" #T_08c4932e_f373_11ea_b3b0_0cc47af5c729row0_col0 {\n",
" background-color: #ebf3eb;\n",
" color: #000000;\n",
" } #T_08c4932e_f373_11ea_b3b0_0cc47af5c729row0_col1 {\n",
" background-color: #ebf3eb;\n",
" color: #000000;\n",
" font-size: 12pt;\n",
" } #T_08c4932e_f373_11ea_b3b0_0cc47af5c729row1_col0 {\n",
" background-color: #ebf3eb;\n",
" color: #000000;\n",
" font-size: 12pt;\n",
" } #T_08c4932e_f373_11ea_b3b0_0cc47af5c729row1_col1 {\n",
" background-color: #ebf3eb;\n",
" color: #000000;\n",
" }</style><table id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729\" ><thead> <tr> <th class=\"blank level0\" ></th> <th class=\"col_heading level0 col0\" >0</th> <th class=\"col_heading level0 col1\" >1</th> </tr></thead><tbody>\n",
" <th id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
" <td id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729row0_col0\" class=\"data row0 col0\" >1.00</td>\n",
" <td id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729row0_col1\" class=\"data row0 col1\" >0.00</td>\n",
" <th id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
" <td id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729row1_col0\" class=\"data row1 col0\" >1.00</td>\n",
" <td id=\"T_08c4932e_f373_11ea_b3b0_0cc47af5c729row1_col1\" class=\"data row1 col1\" >0.00</td>\n",
" </tr>\n",
" </tbody></table>"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x14d92668afd0>"
]
},
"metadata": {},
"output_type": "display_data"
"model = keras.models.load_model('./run/models/best_model.h5')\n",
"\n",
"# ---- Evaluate\n",
"score = model.evaluate(x_test, y_test, verbose=0)\n",
"\n",
"print('x_test / loss : {:5.4f}'.format(score[0]))\n",
"print('x_test / accuracy : {:5.4f}'.format(score[1]))\n",
"\n",
"values=[score[1], 1-score[1]]\n",
"ooo.plot_donut(values,[\"Accuracy\",\"Errors\"], title=\"#### Accuracy donut is :\")\n",
"\n",
"# ---- Confusion matrix\n",
"\n",
"#y_pred = model.predict_classes(x_test) Deprecated after 01/01/2021 !!\n",
"\n",
"y_sigmoid = model.predict(x_test)\n",
"y_pred = np.argmax(y_sigmoid, axis=-1)\n",
"# ooo.display_confusion_matrix(y_test,y_pred,labels=range(2),color='orange',font_size='20pt')\n",
"ooo.display_confusion_matrix(y_test,y_pred,labels=range(2))\n"
"<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}