{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "German Traffic Sign Recognition Benchmark (GTSRB)\n",
    "=================================================\n",
    "---\n",
    "Introduction au Deep Learning  (IDLE) - S. Aria, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020\n",
    "\n",
    "## Episode 2: First Convolutions\n",
    "\n",
    "Our main steps:\n",
    " - Read the dataset\n",
    " - Build a model\n",
    " - Train the model\n",
    " - Evaluate the model\n",
    "\n",
    "## 1/ Import and init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.callbacks import TensorBoard\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "\n",
    "import idle.pwk as ooo\n",
    "from importlib import reload\n",
    "\n",
    "ooo.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2/ Reload dataset\n",
    "The dataset is one of the saved datasets: RGB25, RGB35, L25, L35, etc.  \n",
    "We start with the **L25** dataset.  \n",
    "(With a GPU, it takes only about 35 s, compared to more than 5 min on a CPU.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "dataset = 'set-24x24-L'\n",
    "img_lx  = 24\n",
    "img_ly  = 24\n",
    "img_lz  = 1\n",
    "\n",
    "# ---- Read dataset\n",
    "x_train = np.load('./data/{}/x_train.npy'.format(dataset))\n",
    "y_train = np.load('./data/{}/y_train.npy'.format(dataset))\n",
    "\n",
    "x_test  = np.load('./data/{}/x_test.npy'.format(dataset))\n",
    "y_test  = np.load('./data/{}/y_test.npy'.format(dataset))\n",
    "\n",
    "# ---- Reshape data to (n_samples, img_lx, img_ly, img_lz), the 4D layout expected by Conv2D\n",
    "x_train = x_train.reshape( x_train.shape[0], img_lx, img_ly, img_lz)\n",
    "x_test  = x_test.reshape(  x_test.shape[0],  img_lx, img_ly, img_lz)\n",
    "\n",
    "input_shape = (img_lx, img_ly, img_lz)\n",
    "\n",
    "print(\"Dataset loaded, size={:.1f} MB\\n\".format(ooo.get_directory_size('./data/'+dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3/ Have a look at the dataset\n",
    "Note: the data must be reshaped for matplotlib."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"x_train : \", x_train.shape)\n",
    "print(\"y_train : \", y_train.shape)\n",
    "print(\"x_test  : \", x_test.shape)\n",
    "print(\"y_test  : \", y_test.shape)\n",
    "\n",
    "if img_lz>1:\n",
    "    ooo.plot_images(x_train.reshape(-1,img_lx,img_ly,img_lz), y_train, range(6),  columns=3,  x_size=4, y_size=3)\n",
    "    ooo.plot_images(x_train.reshape(-1,img_lx,img_ly,img_lz), y_train, range(36), columns=12, x_size=1, y_size=1)\n",
    "else:\n",
    "    ooo.plot_images(x_train.reshape(-1,img_lx,img_ly), y_train, range(6),  columns=6,  x_size=2, y_size=2)\n",
    "    ooo.plot_images(x_train.reshape(-1,img_lx,img_ly), y_train, range(36), columns=12, x_size=1, y_size=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4/ Create model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size  =  64\n",
    "num_classes =  43\n",
    "epochs      =  5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential()\n",
    "# ---- Two convolution + max-pooling stages extract image features\n",
    "model.add( keras.layers.Conv2D(96, (3,3), activation='relu', input_shape=(img_lx, img_ly, img_lz)))\n",
    "model.add( keras.layers.MaxPooling2D((2, 2)))\n",
    "model.add( keras.layers.Conv2D(192, (3, 3), activation='relu'))\n",
    "model.add( keras.layers.MaxPooling2D((2, 2)))\n",
    "# ---- Fully connected classifier on the flattened features\n",
    "model.add( keras.layers.Flatten())\n",
    "model.add( keras.layers.Dense(3072, activation='relu'))\n",
    "model.add( keras.layers.Dense(500, activation='relu'))\n",
    "model.add( keras.layers.Dense(500, activation='relu'))\n",
    "# ---- One softmax output per traffic sign class\n",
    "model.add( keras.layers.Dense(num_classes, activation='softmax'))\n",
    "model.summary()\n",
    "\n",
    "# ---- Labels are integer class ids, hence the sparse loss\n",
    "model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])"
   ]
  },
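  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the layer sizes, the feature-map shapes can be worked out by hand. This is a sketch that assumes the Keras defaults used above (`padding='valid'` and stride 1 for `Conv2D`, stride equal to the pool size for `MaxPooling2D`), so that a 3x3 convolution maps a side of $n$ to $n-2$ and a 2x2 pooling maps it to $\\lfloor n/2 \\rfloor$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{input} &: 24 \\times 24 \\times 1 \\\\\\\\\n",
    "\\text{Conv2D}(96,\\ 3 \\times 3) &: (24-2) \\times (24-2) \\times 96 = 22 \\times 22 \\times 96 \\\\\\\\\n",
    "\\text{MaxPooling2D}(2 \\times 2) &: 11 \\times 11 \\times 96 \\\\\\\\\n",
    "\\text{Conv2D}(192,\\ 3 \\times 3) &: (11-2) \\times (11-2) \\times 192 = 9 \\times 9 \\times 192 \\\\\\\\\n",
    "\\text{MaxPooling2D}(2 \\times 2) &: \\lfloor 9/2 \\rfloor \\times \\lfloor 9/2 \\rfloor \\times 192 = 4 \\times 4 \\times 192 \\\\\\\\\n",
    "\\text{Flatten} &: 4 \\cdot 4 \\cdot 192 = 3072\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Under these assumptions the `Flatten` layer outputs 3072 values, the same width as the first `Dense` layer. The shapes can be compared with the output of `model.summary()`."
   ]
  },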
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5/ Run model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(  x_train, y_train,\n",
    "                      batch_size=batch_size,\n",
    "                      epochs=epochs,\n",
    "                      verbose=1,\n",
    "                      validation_data=(x_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6/ Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score = model.evaluate(x_test, y_test, verbose=0)\n",
    "\n",
    "print('Test loss      : {:5.4f}'.format(score[0]))\n",
    "print('Test accuracy  : {:5.4f}'.format(score[1]))"
   ]
  },
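  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two numbers returned by `model.evaluate` are the loss and the metric chosen in `model.compile`. In the usual formulation (the notation here is an assumption, not taken from the notebook), with $N$ test images, integer labels $y_i$ and softmax outputs $p_{i,k}$ for class $k$:\n",
    "\n",
    "$$\n",
    "\\text{loss} = -\\frac{1}{N}\\sum_{i=1}^{N} \\log p_{i,y_i},\n",
    "\\qquad\n",
    "\\text{accuracy} = \\frac{1}{N}\\sum_{i=1}^{N} \\mathbf{1}\\left[\\arg\\max_k\\, p_{i,k} = y_i\\right]\n",
    "$$\n",
    "\n",
    "As a point of comparison, a model that outputs a uniform probability of $1/43$ for every class would have a loss of $\\ln 43 \\approx 3.76$."
   ]
  },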
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "### Results:  \n",
    "L25: size=250 MB, accuracy=93.15%  \n",
    "..."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}