{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
"\n",
"# <!-- TITLE --> [MNIST1] - Simple classification with DNN\n",
"<!-- DESC --> An example of classification using a dense neural network for the famous MNIST dataset\n",
"<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
"\n",
"## Objectives :\n",
" - Recognizing handwritten digits\n",
" - Understanding the principle of a DNN classifier \n",
" - Implementation with Keras \n",
"\n",
"\n",
"The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) (Modified National Institute of Standards and Technology) is a must for Deep Learning. \n",
"It consists of 60,000 small images of handwritten digits for training and 10,000 for testing.\n",
"\n",
"## What we're going to do :\n",
"\n",
" - Retrieve data\n",
" - Prepare the data\n",
" - Create a model\n",
" - Train the model\n",
" - Evaluate the result\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1 - Init python stuff"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sys,os\n",
"from importlib import reload\n",
"sys.path.append('..')\n",
"import fidle.pwk as pwk\n",
"datasets_dir = pwk.init('MNIST1')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2 - Retrieve data\n",
"MNIST is one of the most famous historical datasets. \n",
"It is included in [Keras datasets](https://www.tensorflow.org/api_docs/python/tf/keras/datasets)"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"\n",
"print(\"x_train : \",x_train.shape)\n",
"print(\"y_train : \",y_train.shape)\n",
"print(\"x_test : \",x_test.shape)\n",
"print(\"y_test : \",y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3 - Preparing the data"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print('Before normalization : Min={}, max={}'.format(x_train.min(),x_train.max()))\n",
"\n",
"xmax=x_train.max()\n",
"x_train = x_train / xmax\n",
"x_test = x_test / xmax\n",
"\n",
"print('After normalization : Min={}, max={}'.format(x_train.min(),x_train.max()))"
]
},
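{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dividing an integer image array by its maximum yields floating-point values between 0 and 1. \n",
"A minimal sketch of that behaviour on a hypothetical 3-pixel array (the names `pixels` and `scaled` are illustrative and are not part of the notebook's data):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical 8-bit grayscale pixels, stored like the raw MNIST images (uint8)\n",
"pixels = np.array([[0, 128, 255]], dtype=np.uint8)\n",
"\n",
"# True division converts the integers to floats in the range [0, 1]\n",
"scaled = pixels / pixels.max()\n",
"\n",
"print('dtype before :', pixels.dtype, ' dtype after :', scaled.dtype)\n",
"print('Min={}, max={}'.format(scaled.min(), scaled.max()))"
]
},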
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Have a look"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"pwk.plot_images(x_train, y_train, [27], x_size=5,y_size=5, colorbar=True, save_as='01-one-digit')\n",
"pwk.plot_images(x_train, y_train, range(5,41), columns=12, save_as='02-many-digits')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4 - Create model\n",
"Further information about : \n",
" - [Optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers)\n",
" - [Activation](https://www.tensorflow.org/api_docs/python/tf/keras/activations)\n",
" - [Loss](https://www.tensorflow.org/api_docs/python/tf/keras/losses)\n",
" - [Metrics](https://www.tensorflow.org/api_docs/python/tf/keras/metrics)"
]
},
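{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the softmax activation and the sparse categorical cross-entropy loss more concrete, here is a small NumPy sketch. \n",
"It is only an illustration and not how Keras implements them internally: the helper functions and the example scores are made up for this demonstration."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def softmax(z):\n",
"    # Subtract the row-wise max for numerical stability, then normalize\n",
"    e = np.exp(z - z.max(axis=-1, keepdims=True))\n",
"    return e / e.sum(axis=-1, keepdims=True)\n",
"\n",
"def sparse_categorical_crossentropy(y_true, probs):\n",
"    # Mean negative log-probability assigned to the true integer class\n",
"    return -np.mean(np.log(probs[np.arange(len(y_true)), y_true]))\n",
"\n",
"# Hypothetical raw scores for 2 samples and 3 classes, with integer labels\n",
"logits = np.array([[2.0, 1.0, 0.1],\n",
"                   [0.5, 2.5, 0.3]])\n",
"labels = np.array([0, 1])\n",
"\n",
"probs = softmax(logits)\n",
"loss  = sparse_categorical_crossentropy(labels, probs)\n",
"\n",
"print('Probabilities :', probs.round(3))\n",
"print('Row sums      :', probs.sum(axis=1))\n",
"print('Loss          :', loss)"
]
},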
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"hidden1 = 100\n",
"hidden2 = 100\n",
"\n",
"model = keras.Sequential([\n",
" keras.layers.Input((28,28)),\n",
" keras.layers.Flatten(),\n",
" keras.layers.Dense( hidden1, activation='relu'),\n",
" keras.layers.Dense( hidden2, activation='relu'),\n",
" keras.layers.Dense( 10, activation='softmax')\n",
"])\n",
"\n",
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
]
},
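{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on this architecture, the number of trainable parameters can be computed by hand: each Dense layer has (inputs x units) weights plus one bias per unit. \n",
"A minimal sketch, assuming the layer sizes used above (784 inputs after flattening, then 100, 100 and 10 units); the variable names are illustrative. The total should match what `model.summary()` reports."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Layer sizes: 28*28 flattened inputs, two hidden layers, 10 output classes\n",
"sizes = [28*28, 100, 100, 10]\n",
"\n",
"# Each Dense layer has n_in*n_out weights plus n_out biases\n",
"params = [n_in*n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:])]\n",
"\n",
"print('Parameters per layer :', params)\n",
"print('Total parameters     :', sum(params))"
]
},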
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5 - Train the model"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"batch_size = 512\n",
"epochs = 16\n",
"\n",
"history = model.fit( x_train, y_train,\n",
" batch_size = batch_size,\n",
" epochs = epochs,\n",
" verbose = 1,\n",
" validation_data = (x_test, y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6 - Evaluate\n",
"### 6.1 - Final loss and accuracy"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"score = model.evaluate(x_test, y_test, verbose=0)\n",
"\n",
"print('Test loss :', score[0])\n",
"print('Test accuracy :', score[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.2 - Plot history"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"pwk.plot_history(history, figsize=(6,4), save_as='03-history')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.3 - Plot results"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# model.predict_classes(x_test) is deprecated (after 01/01/2021); take the argmax of the softmax output instead\n",
"\n",
"y_softmax = model.predict(x_test)\n",
"y_pred    = np.argmax(y_softmax, axis=-1)\n",
"\n",
"pwk.plot_images(x_test, y_test, range(0,200), columns=12, x_size=1, y_size=1, y_pred=y_pred, save_as='04-predictions')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.4 - Plot some errors"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"errors=[ i for i in range(len(x_test)) if y_pred[i]!=y_test[i] ]\n",
"errors=errors[:min(24,len(errors))]\n",
"pwk.plot_images(x_test, y_test, errors[:15], columns=6, x_size=2, y_size=2, y_pred=y_pred, save_as='05-some-errors')"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"pwk.plot_confusion_matrix(y_test,y_pred,range(10),normalize=True, save_as='06-confusion-matrix')"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"pwk.end()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div class=\"todo\">\n",
" A few things you can do for fun:\n",
" <ul>\n",
" <li>Change the network architecture (layers, number of neurons, etc.)</li>\n",
" <li>Display a summary of the network</li>\n",
" <li>Retrieve and display the softmax output of the network, to evaluate its \"doubts\".</li>\n",
" </ul>\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}