{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XJy9QoDC7XA7"
      },
      "source": [
        "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n",
        "\n",
        "# <!-- TITLE --> [DRL2] - RL Baselines3 Zoo: Training in Colab\n",
        "<!-- DESC --> Demo of Stable-Baselines3 with Colab\n",
        "<!-- AUTHOR : Nathan Cassereau (IDRIS) and Bertrand Cabot (IDRIS) -->\n",
        "\n",
        "\n",
        "Demo of Stable-Baselines3, adapted by Nathan Cassereau (IDRIS) and Bertrand Cabot (IDRIS)\n",
        "\n",
        "\n",
        "RL Baselines3 Zoo Repo: [https://github.com/DLR-RM/rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo)\n",
        "\n",
        "Stable-Baselines3 Repo: [https://github.com/DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3)\n",
        "\n",
        "\n",
        "# Install Dependencies\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AXVDDlTn02M9"
      },
      "outputs": [],
      "source": [
        "!apt-get install -y swig cmake ffmpeg freeglut3-dev xvfb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f1s9_3b1Tsq9"
      },
      "outputs": [],
      "source": [
        "!apt-get install -y \\\n",
        "    libgl1-mesa-dev \\\n",
        "    libgl1-mesa-glx \\\n",
        "    libglew-dev \\\n",
        "    libosmesa6-dev \\\n",
        "    software-properties-common\n",
        "\n",
        "!apt-get install -y patchelf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kDjF3qRg7oGH"
      },
      "source": [
        "## Clone RL Baselines3 Zoo Repo"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SCjGikdT1DFy"
      },
      "outputs": [],
      "source": [
        "!git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "REMQlh-ezyVt"
      },
      "outputs": [],
      "source": [
        "%cd /content/rl-baselines3-zoo/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5tmD_QTBqTMb"
      },
      "source": [
        "### Install pip dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OWIDzgJTqShY"
      },
      "outputs": [],
      "source": [
        "!pip install -r requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VlOD-cImRMwW"
      },
      "outputs": [],
      "source": [
        "!pip install free-mujoco-py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kYPotFDF0Noa"
      },
      "source": [
        "## Pretrained model\n",
        "\n",
        "gym environments: https://gym.openai.com/envs/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kgOp2XIklaf2"
      },
      "outputs": [],
      "source": [
        "%cd /content/rl-baselines3-zoo/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xVm9QPNVwKXN"
      },
      "source": [
        "### Record a Video"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MPyfQxD5z26J"
      },
      "outputs": [],
      "source": [
        "# Set up display; otherwise rendering will fail\n",
        "import os\n",
        "os.system(\"Xvfb :1 -screen 0 1024x768x24 &\")\n",
        "os.environ['DISPLAY'] = ':1'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZC3OTfpf8CXu"
      },
      "outputs": [],
      "source": [
        "import base64\n",
        "from pathlib import Path\n",
        "\n",
        "from IPython import display as ipythondisplay\n",
        "\n",
        "def show_videos(video_path='', prefix=''):\n",
        "  \"\"\"\n",
        "  Taken from https://github.com/eleurent/highway-env\n",
        "\n",
        "  :param video_path: (str) Path to the folder containing videos\n",
        "  :param prefix: (str) Filter the video, showing only the only starting with this prefix\n",
        "  \"\"\"\n",
        "  html = []\n",
        "  for mp4 in Path(video_path).glob(\"**/*{}*.mp4\".format(prefix)):\n",
        "      video_b64 = base64.b64encode(mp4.read_bytes())\n",
        "      html.append('''{} <br> <video alt=\"{}\" autoplay \n",
        "                    loop controls style=\"height: 400px;\">\n",
        "                    <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
        "                </video>'''.format(mp4, mp4, video_b64.decode('ascii')))\n",
        "  ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LW-7EWA50550"
      },
      "source": [
        "### Discrete environments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ArN0B6YX0m4z"
      },
      "outputs": [],
      "source": [
        "%run scripts/all_plots.py -a dqn qrdqn a2c ppo --env PongNoFrameskip-v4 -f rl-trained-agents/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5zmVA6JADKvV"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a dqn -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BkcI91hABKUm"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a qrdqn -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E5s7UHn2DmqQ"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a a2c -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0aP2A-NqCQSC"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a ppo -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "itP336W6HdeC"
      },
      "outputs": [],
      "source": [
        "!python enjoy.py --algo dqn --env PongNoFrameskip-v4 --no-render --n-timesteps 5000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ip3AauLzwNGP"
      },
      "outputs": [],
      "source": [
        "!python -m utils.record_video --algo dqn --env PongNoFrameskip-v4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oKOjFuwK9HI0"
      },
      "outputs": [],
      "source": [
        "show_videos(video_path='rl-trained-agents/dqn', prefix='PongNoFrameskip-v4')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8kD_C440-xvw"
      },
      "source": [
        "### Continuous environments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2yJ0eOyjLM5O"
      },
      "outputs": [],
      "source": [
        "%run scripts/all_plots.py -a ppo trpo sac td3 tqc --env Ant-v3 -f rl-trained-agents/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iQAJmndtNCvN"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a ppo -e Ant-v3 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O1L_ze67NDf9"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a trpo -e Ant-v3 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FPn0iySFNzCQ"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a tqc -e Ant-v3 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3vDHeNNTNwsT"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a td3 -e Ant-v3 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dQDNhdlgOzFs"
      },
      "outputs": [],
      "source": [
        "%run scripts/plot_train.py -a sac -e Ant-v3 -f rl-trained-agents/ -x time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s9aT46m8QXzt"
      },
      "outputs": [],
      "source": [
        "!python enjoy.py --algo td3 --env Ant-v3 --no-render --n-timesteps 5000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zx7gZ23UQYQ4"
      },
      "outputs": [],
      "source": [
        "!python -m utils.record_video --algo td3 --env Ant-v3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yvETH_M_Unw3"
      },
      "outputs": [],
      "source": [
        "show_videos(video_path='rl-trained-agents/td3', prefix='Ant-v3')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6gJ-pAbF7zRZ"
      },
      "source": [
        "## Train an RL Agent\n",
        "\n",
        "\n",
        "The trained agent can be found in the `logs/` folder.\n",
        "\n",
        "Here we will train DQN on the PongNoFrameskip-v4 (Atari Pong) environment for 1,000,000 steps.\n",
        "\n",
        "\n",
        "To train on a different environment, pass its id with `--env`, for example `--env CartPole-v1`.\n",
        "\n",
        "Note: You need to update `hyperparams/algo.yml` (where `algo` is the algorithm name, e.g. `dqn`) to support new environments. You can edit it from the file browser in the side panel of Google Colab (see https://stackoverflow.com/questions/46986398/import-data-into-google-colaboratory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "9bIR_N7R11XI"
      },
      "outputs": [],
      "source": [
        "!python train.py --algo dqn --env PongNoFrameskip-v4 --n-timesteps 1000000"
      ]
    },
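    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, `train.py` builds on the Stable-Baselines3 Python API. The sketch below is a minimal, assumed equivalent for a small task (A2C on CartPole-v1 with default hyperparameters). It is not what the zoo script runs, since the zoo also loads tuned hyperparameters and environment wrappers from its YAML files. The function name `train_a2c_cartpole` and the save path are illustrative.\n",
        "\n",
        "```python\n",
        "def train_a2c_cartpole(total_timesteps=100_000, save_path='a2c_cartpole'):\n",
        "    # Imported lazily so that defining the function does not require the library.\n",
        "    from stable_baselines3 import A2C\n",
        "    from stable_baselines3.common.evaluation import evaluate_policy\n",
        "\n",
        "    # Default MLP policy and default hyperparameters (the zoo would use tuned ones).\n",
        "    model = A2C('MlpPolicy', 'CartPole-v1', verbose=0)\n",
        "    model.learn(total_timesteps=total_timesteps)\n",
        "    model.save(save_path)  # writes save_path + '.zip'\n",
        "\n",
        "    # Mean and standard deviation of the episodic reward over 10 evaluation episodes.\n",
        "    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)\n",
        "    return mean_reward, std_reward\n",
        "```\n",
        "\n",
        "Calling `train_a2c_cartpole(total_timesteps=10_000)` can serve as a quick smoke test before launching a long run."
      ]
    },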
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-fHBq73665yD"
      },
      "source": [
        "#### Evaluate trained agent\n",
        "\n",
        "\n",
        "Remove `--folder logs/` to evaluate the pretrained agent instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bw8YuEgU6bT3"
      },
      "outputs": [],
      "source": [
        "!python enjoy.py --algo dqn --env PongNoFrameskip-v4 --no-render --n-timesteps 5000 --folder logs/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w5Il2J0VHPLC"
      },
      "source": [
        "#### Tune Hyperparameters\n",
        "\n",
        "We use [Optuna](https://optuna.org/) for optimizing the hyperparameters.\n",
        "\n",
        "The command below (commented out by default) tunes the hyperparameters for DQN using a TPE sampler and a median pruner, 5 parallel jobs,\n",
        "a budget of 10 trials and a maximum of 5000 steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w2sC22eGHTH-"
      },
      "outputs": [],
      "source": [
        "#!python train.py --algo dqn --env PongNoFrameskip-v4 -n 5000 -optimize --n-trials 10 --n-jobs 5 --sampler tpe --pruner median"
      ]
    },
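    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To give an intuition for the median pruner used above, the sketch below implements a simplified version of the rule in plain Python: a trial is stopped when its intermediate reward at a given step falls below the median of earlier trials at that same step. This is an illustration only, not Optuna's implementation; `should_prune` and the layout of `history` are hypothetical.\n",
        "\n",
        "```python\n",
        "from statistics import median\n",
        "\n",
        "def should_prune(history, step, value):\n",
        "    # history: one list of intermediate rewards per earlier trial (hypothetical layout)\n",
        "    previous = [trial[step] for trial in history if len(trial) > step]\n",
        "    if not previous:\n",
        "        return False  # nothing to compare against yet\n",
        "    # Higher reward is better, so stop trials that fall below the median.\n",
        "    return value < median(previous)\n",
        "\n",
        "history = [[1.0, 2.0, 3.0], [0.5, 1.0, 1.5], [2.0, 4.0, 6.0]]\n",
        "print(should_prune(history, step=1, value=1.5))  # True: 1.5 is below the median 2.0\n",
        "print(should_prune(history, step=1, value=2.5))  # False: 2.5 is above the median 2.0\n",
        "```\n",
        "\n",
        "Pruning early in this way spends the trial budget on configurations that already look promising, at the risk of discarding slow starters."
      ]
    },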
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qBuUfnzI8DN6"
      },
      "source": [
        "### Display the video"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RjdpP0HE8D2p"
      },
      "source": [
        "### Continue Training\n",
        "\n",
        "Here, we continue training the previously trained model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zgMZQJJF6u1C"
      },
      "outputs": [],
      "source": [
        "#!python train.py --algo dqn --env PongNoFrameskip-v4  --n-timesteps 50000 -i logs/dqn/PongNoFrameskip-v4_1/PongNoFrameskip-v4.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSaoyiAE8cVj"
      },
      "outputs": [],
      "source": [
        "#!python enjoy.py --algo dqn --env PongNoFrameskip-v4 --no-render --n-timesteps 1000 --folder logs/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SKglp1awG6c6"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "FIDLE rl-baselines-zoo.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}