{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n",
    "\n",
    "# <!-- TITLE --> [GTSRB5] - Full convolutions\n",
    "<!-- DESC --> Episode 5 : A lot of models, a lot of datasets and a lot of results.\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    "  - Try multiple solutions\n",
    "  - Design a generic and batch-usable code\n",
    "  \n",
    "The German Traffic Sign Recognition Benchmark (GTSRB) is a dataset with more than 50,000 photos of road signs from about 40 classes.  \n",
    "The final aim is to recognise them !  \n",
    "Description is available there : http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset\n",
    "\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    "Our main steps:\n",
    " - Try n models with n datasets\n",
    " - Save a Pandas/h5 report\n",
    " - Write to be run in batch mode\n",
    "\n",
    "## Step 1 - Import and init\n",
    "### 1.1 - Python stuffs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import h5py\n",
    "import sys,os,time,json\n",
    "import random\n",
    "from IPython.display import display\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as pwk\n",
    "\n",
    "VERSION='1.6'\n",
    "\n",
    "sys.path.append('..')\n",
    "import fidle.pwk as ooo\n",
    "\n",
    "run_dir = './run/GTSRB5'\n",
    "datasets_dir = pwk.init('GTSRB5', run_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - Parameters\n",
    "**Note:**   \n",
    "With a dataset of 20% and a scale of 0.1, it takes only 2-3' of a laptop...  \n",
    "With a dataset of 100% and a scale of 1, it takes 30' on a V100 GPU !\n",
    "\n",
    "`fit_verbosity` is the verbosity during training : 0 = silent, 1 = progress bar, 2 = one line per epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enhanced_dir = f'./data'\n",
    "# enhanced_dir = f'{datasets_dir}/GTSRB/enhanced'\n",
    "\n",
    "# ---- For tests\n",
    "datasets      = ['set-24x24-L', 'set-24x24-RGB', 'set-48x48-RGB']\n",
    "models        = {'v1':'get_model_v1', 'v2':'get_model_v2', 'v3':'get_model_v3'}\n",
    "batch_size    = 64\n",
    "epochs        = 5\n",
    "scale         = 0.1\n",
    "with_datagen  = False\n",
    "fit_verbosity = 0\n",
    "\n",
    "# ---- All possibilities\n",
    "# datasets      = ['set-24x24-L', 'set-24x24-RGB', 'set-48x48-L', 'set-48x48-RGB', 'set-24x24-L-LHE', 'set-24x24-RGB-HE', 'set-48x48-L-LHE', 'set-48x48-RGB-HE']\n",
    "# models        = {'v1':'get_model_v1', 'v2':'get_model_v2', 'v3':'get_model_v3'}\n",
    "# batch_size    = 64\n",
    "# epochs        = 16\n",
    "# scale         = 1\n",
    "# with_datagen  = False\n",
    "# fit_verbosity = 0\n",
    "\n",
    "# ---- Data augmentation\n",
    "# datasets      = ['set-48x48-RGB']\n",
    "# models        = {'v2':'get_model_v2'}\n",
    "# batch_size    = 64\n",
    "# epochs        = 20\n",
    "# scale         = 1\n",
    "# with_datagen  = True\n",
    "# fit_verbosity = 0"
   ]
  },
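  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every dataset will be combined with every model. Before launching anything heavy, a quick sanity check (an illustration, not part of the original pipeline) can enumerate the dataset/model pairs that `multi_run()` will train below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Illustration only : preview the run grid defined above\n",
    "#\n",
    "from itertools import product\n",
    "\n",
    "runs = list(product(datasets, models.keys()))\n",
    "print(f'{len(runs)} trainings planned ({len(datasets)} datasets x {len(models)} models) :')\n",
    "for d_name, m_name in runs:\n",
    "    print(f'    {d_name:20s} with model {m_name}')"
   ]
  },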
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.override('enhanced_dir', 'datasets', 'models', 'batch_size', 'epochs', 'scale', 'with_datagen', 'fit_verbosity')"
   ]
  },
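  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How the override actually works is internal to `fidle.pwk`. As an illustration only (the variable name below is hypothetical, not the one pwk uses), a common pattern is to let an environment variable take precedence over the notebook value, so a batch script can reconfigure a run without editing the notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Illustration only : a hand-rolled override pattern\n",
    "#      MY_OVERRIDE_EPOCHS is a hypothetical name, not the one pwk expects\n",
    "#\n",
    "epochs = int(os.getenv('MY_OVERRIDE_EPOCHS', epochs))\n",
    "print('epochs =', epochs)"
   ]
  },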
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(time.time())\n",
    "\n",
    "# ---- Where I am ?\n",
    "now    = time.strftime(\"%A %d %B %Y - %Hh%Mm%Ss\")\n",
    "here   = os.getcwd()\n",
    "tag_id = '{:06}'.format(random.randint(0,99999))\n",
    "\n",
    "# ---- Who I am ?\n",
    "oar_id   = os.getenv(\"OAR_JOB_ID\",  \"??\")\n",
    "slurm_id = os.getenv(\"SLURM_JOBID\", \"??\")\n",
    "\n",
    "print('Full Convolutions Notebook :')\n",
    "print('  Version            : ', VERSION  )\n",
    "print('  Now is             : ', now      )\n",
    "print('  OAR id             : ', oar_id   )\n",
    "print('  SLURM id           : ', slurm_id )\n",
    "print('  Tag id             : ', tag_id   )\n",
    "print('  Working directory  : ', here     )\n",
    "print('  Output  directory  : ', run_dir  )\n",
    "print('  for tensorboard    : ', f'--logdir {run_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Uncomment for batch tests\n",
    "#\n",
    "# print(\"\\n\\n*** Test mode - Exit before making big treatments... ***\\n\\n\")\n",
    "# sys.exit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Dataset loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_dataset(enhanced_dir, dataset_name):\n",
    "    '''Reads h5 dataset from dataset_dir\n",
    "    Args:\n",
    "        dataset_dir : datasets dir\n",
    "        name        : dataset name, without .h5\n",
    "    Returns:    x_train,y_train,x_test,y_test data'''\n",
    "    # ---- Read dataset\n",
    "    filename = f'{enhanced_dir}/{dataset_name}.h5'\n",
    "    size     = os.path.getsize(filename)/(1024*1024)\n",
    "\n",
    "    with  h5py.File(filename,'r') as f:\n",
    "        x_train = f['x_train'][:]\n",
    "        y_train = f['y_train'][:]\n",
    "        x_test  = f['x_test'][:]\n",
    "        y_test  = f['y_test'][:]\n",
    "\n",
    "    # ---- Shuffle\n",
    "    x_train,y_train=pwk.shuffle_np_dataset(x_train,y_train)\n",
    "\n",
    "    # ---- done\n",
    "    return x_train,y_train,x_test,y_test,size"
   ]
  },
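  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The enhanced `.h5` files themselves are built in a previous episode of this series. For reference, a minimal sketch (commented out, with dummy random data and a hypothetical filename) of how a file compatible with `read_dataset()` could be written with `h5py`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Sketch only : expected h5 layout (dummy data, hypothetical filename)\n",
    "#\n",
    "# x = np.random.rand(100, 24, 24, 3).astype(np.float32)   # 100 dummy 24x24 RGB images\n",
    "# y = np.random.randint(0, 43, 100)                        # 100 dummy labels (43 classes)\n",
    "#\n",
    "# with h5py.File('./data/set-demo.h5', 'w') as f:\n",
    "#     f.create_dataset('x_train', data=x)\n",
    "#     f.create_dataset('y_train', data=y)\n",
    "#     f.create_dataset('x_test',  data=x[:20])\n",
    "#     f.create_dataset('y_test',  data=y[:20])"
   ]
  },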
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Models collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# A basic model\n",
    "#\n",
    "def get_model_v1(lx,ly,lz):\n",
    "    \n",
    "    model = keras.models.Sequential()\n",
    "    \n",
    "    model.add( keras.layers.Conv2D(96, (3,3), activation='relu', input_shape=(lx,ly,lz)))\n",
    "    model.add( keras.layers.MaxPooling2D((2, 2)))\n",
    "    model.add( keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add( keras.layers.Conv2D(192, (3, 3), activation='relu'))\n",
    "    model.add( keras.layers.MaxPooling2D((2, 2)))\n",
    "    model.add( keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add( keras.layers.Flatten()) \n",
    "    model.add( keras.layers.Dense(1500, activation='relu'))\n",
    "    model.add( keras.layers.Dropout(0.5))\n",
    "\n",
    "    model.add( keras.layers.Dense(43, activation='softmax'))\n",
    "    return model\n",
    "    \n",
    "# A more sophisticated model\n",
    "#\n",
    "def get_model_v2(lx,ly,lz):\n",
    "    model = keras.models.Sequential()\n",
    "\n",
    "    model.add( keras.layers.Conv2D(64, (3, 3), padding='same', input_shape=(lx,ly,lz), activation='relu'))\n",
    "    model.add( keras.layers.Conv2D(64, (3, 3), activation='relu'))\n",
    "    model.add( keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add( keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add( keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'))\n",
    "    model.add( keras.layers.Conv2D(128, (3, 3), activation='relu'))\n",
    "    model.add( keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add( keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add( keras.layers.Conv2D(256, (3, 3), padding='same',activation='relu'))\n",
    "    model.add( keras.layers.Conv2D(256, (3, 3), activation='relu'))\n",
    "    model.add( keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add( keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add( keras.layers.Flatten())\n",
    "    model.add( keras.layers.Dense(512, activation='relu'))\n",
    "    model.add( keras.layers.Dropout(0.5))\n",
    "    model.add( keras.layers.Dense(43, activation='softmax'))\n",
    "    return model\n",
    "\n",
    "def get_model_v3(lx,ly,lz):\n",
    "    model = keras.models.Sequential()\n",
    "    model.add(tf.keras.layers.Conv2D(32, (5, 5), padding='same',  activation='relu', input_shape=(lx,ly,lz)))\n",
    "    model.add(tf.keras.layers.BatchNormalization(axis=-1))      \n",
    "    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(tf.keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add(tf.keras.layers.Conv2D(64, (5, 5), padding='same',  activation='relu'))\n",
    "    model.add(tf.keras.layers.BatchNormalization(axis=-1))\n",
    "    model.add(tf.keras.layers.Conv2D(128, (5, 5), padding='same', activation='relu'))\n",
    "    model.add(tf.keras.layers.BatchNormalization(axis=-1))\n",
    "    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(tf.keras.layers.Dropout(0.2))\n",
    "\n",
    "    model.add(tf.keras.layers.Flatten())\n",
    "    model.add(tf.keras.layers.Dense(512, activation='relu'))\n",
    "    model.add(tf.keras.layers.BatchNormalization())\n",
    "    model.add(tf.keras.layers.Dropout(0.4))\n",
    "\n",
    "    model.add(tf.keras.layers.Dense(43, activation='softmax'))\n",
    "    return model"
   ]
  },
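  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching the full grid, it can be useful to compare the size of the three candidates. A small check (using a 24x24 RGB input as an example) that instantiates each model and prints its number of parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Illustration only : size of each candidate model, for a 24x24 RGB input\n",
    "#\n",
    "for m_name, m_function in models.items():\n",
    "    model = globals()[m_function](24, 24, 3)\n",
    "    print(f'    {m_name} ({m_function}) : {model.count_params():,d} parameters')"
   ]
  },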
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Multiple datasets, multiple models ;-)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_run(enhanced_dir, datasets, models, datagen=None,\n",
    "              scale=1, batch_size=64, epochs=16, \n",
    "              fit_verbosity=0, tag_id='last'):\n",
    "    \"\"\"\n",
    "    Launches a dataset-model combination\n",
    "    args:\n",
    "        enhanced_dir   : Directory of the enhanced datasets\n",
    "        datasets       : List of dataset (whitout .h5)\n",
    "        models         : List of model like { \"model name\":get_model(), ...}\n",
    "        datagen        : Data generator or None (None)\n",
    "        scale          : % of dataset to use.  1 mean all. (1)\n",
    "        batch_size     : Batch size (64)\n",
    "        epochs         : Number of epochs (16)\n",
    "        fit_verbosity  : Verbose level (0)\n",
    "        tag_id         : postfix for report, logs and models dir (_last)\n",
    "    return:\n",
    "        report        : Report as a dict for Pandas.\n",
    "    \"\"\"  \n",
    "    # ---- Logs and models dir\n",
    "    #\n",
    "    os.makedirs(f'{run_dir}/logs_{tag_id}',   mode=0o750, exist_ok=True)\n",
    "    os.makedirs(f'{run_dir}/models_{tag_id}', mode=0o750, exist_ok=True)\n",
    "    \n",
    "    # ---- Columns of output\n",
    "    #\n",
    "    output={}\n",
    "    output['Dataset'] = []\n",
    "    output['Size']    = []\n",
    "    for m in models:\n",
    "        output[m+'_Accuracy'] = []\n",
    "        output[m+'_Duration'] = []\n",
    "\n",
    "    # ---- Let's go\n",
    "    #\n",
    "    for d_name in datasets:\n",
    "        print(\"\\nDataset : \",d_name)\n",
    "\n",
    "        # ---- Read dataset\n",
    "        x_train,y_train,x_test,y_test, d_size = read_dataset(enhanced_dir, d_name)\n",
    "        output['Dataset'].append(d_name)\n",
    "        output['Size'].append(d_size)\n",
    "        \n",
    "        # ---- Rescale\n",
    "        x_train,y_train,x_test,y_test = pwk.rescale_dataset(x_train,y_train,x_test,y_test, scale=scale)\n",
    "        \n",
    "        # ---- Get the shape\n",
    "        (n,lx,ly,lz) = x_train.shape\n",
    "\n",
    "        # ---- For each model\n",
    "        for m_name,m_function in models.items():\n",
    "            print(\"    Run model {}  : \".format(m_name), end='')\n",
    "            # ---- get model\n",
    "            try:\n",
    "                # ---- get function by name\n",
    "                m_function=globals()[m_function]\n",
    "                model=m_function(lx,ly,lz)\n",
    "                # ---- Compile it\n",
    "                model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
    "                # ---- Callbacks tensorboard\n",
    "                log_dir = f'{run_dir}/logs_{tag_id}/tb_{d_name}_{m_name}'\n",
    "                tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
    "                # ---- Callbacks bestmodel\n",
    "                save_dir = f'{run_dir}/models_{tag_id}/model_{d_name}_{m_name}.h5'\n",
    "                bestmodel_callback = tf.keras.callbacks.ModelCheckpoint(filepath=save_dir, verbose=0, monitor='accuracy', save_best_only=True)\n",
    "                # ---- Train\n",
    "                start_time = time.time()\n",
    "                if datagen==None:\n",
    "                    # ---- No data augmentation (datagen=None) --------------------------------------\n",
    "                    history = model.fit(x_train, y_train,\n",
    "                                        batch_size      = batch_size,\n",
    "                                        epochs          = epochs,\n",
    "                                        verbose         = fit_verbosity,\n",
    "                                        validation_data = (x_test, y_test),\n",
    "                                        callbacks       = [tensorboard_callback, bestmodel_callback])\n",
    "                else:\n",
    "                    # ---- Data augmentation (datagen given) ----------------------------------------\n",
    "                    datagen.fit(x_train)\n",
    "                    history = model.fit(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                                        steps_per_epoch = int(len(x_train)/batch_size),\n",
    "                                        epochs          = epochs,\n",
    "                                        verbose         = fit_verbosity,\n",
    "                                        validation_data = (x_test, y_test),\n",
    "                                        callbacks       = [tensorboard_callback, bestmodel_callback])\n",
    "                    \n",
    "                # ---- Result\n",
    "                end_time = time.time()\n",
    "                duration = end_time-start_time\n",
    "                accuracy = max(history.history[\"val_accuracy\"])*100\n",
    "                #\n",
    "                output[m_name+'_Accuracy'].append(accuracy)\n",
    "                output[m_name+'_Duration'].append(duration)\n",
    "                print(f\"Accuracy={accuracy: 7.2f}    Duration={duration: 7.2f}\")\n",
    "            except:\n",
    "                raise\n",
    "                output[m_name+'_Accuracy'].append('0')\n",
    "                output[m_name+'_Duration'].append('999')\n",
    "                print('-')\n",
    "    return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - Run !"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.chrono_start()\n",
    "\n",
    "print('\\n---- Run','-'*50)\n",
    "\n",
    "\n",
    "# ---- Data augmentation or not\n",
    "#\n",
    "if with_datagen :\n",
    "    datagen = keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,\n",
    "                                                           featurewise_std_normalization=False,\n",
    "                                                           width_shift_range=0.1,\n",
    "                                                           height_shift_range=0.1,\n",
    "                                                           zoom_range=0.2,\n",
    "                                                           shear_range=0.1,\n",
    "                                                           rotation_range=10.)\n",
    "else:\n",
    "    datagen=None\n",
    "    \n",
    "# ---- Run\n",
    "#\n",
    "output = multi_run(enhanced_dir,\n",
    "                   datasets, \n",
    "                   models,\n",
    "                   datagen       = datagen,\n",
    "                   scale         = scale,\n",
    "                   batch_size    = batch_size,\n",
    "                   epochs        = epochs,\n",
    "                   fit_verbosity = fit_verbosity,\n",
    "                   tag_id        = tag_id)\n",
    "\n",
    "# ---- Save report\n",
    "#\n",
    "report={}\n",
    "report['output']=output\n",
    "report['description'] = f'scale={scale} batch_size={batch_size} epochs={epochs} data_aug={with_datagen}'\n",
    "\n",
    "report_name=f'{run_dir}/report_{tag_id}.json'\n",
    "\n",
    "with open(report_name, 'w') as file:\n",
    "    json.dump(report, file, indent=4)\n",
    "\n",
    "print('\\nReport saved as ',report_name)\n",
    "\n",
    "pwk.chrono_show()\n",
    "print('-'*59)\n"
   ]
  },
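  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In batch mode, several such reports accumulate in `run_dir`. To exploit one afterwards, it can be loaded straight back into a Pandas DataFrame; a minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Reload the report we just saved, and view it as a DataFrame\n",
    "#\n",
    "import pandas as pd\n",
    "\n",
    "with open(report_name) as file:\n",
    "    report = json.load(file)\n",
    "\n",
    "print(report['description'])\n",
    "display(pd.DataFrame(report['output']))"
   ]
  },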
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7 - That's all folks.."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pwk.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}